# Framework for finding temporal differences in ML-models trained in different timelines

### Installations

In [1]:
! pip install sentence-transformers openai keybert requests anytree tqdm
#sudo apt-get install graphviz


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


### Imports

In [21]:
from sentence_transformers import SentenceTransformer
import re
import openai
import torch
import logging
from nutree import Tree, Node
import json
from anytree import Node, RenderTree,find_by_attr,find,findall
from anytree.exporter import JsonExporter
from anytree.importer import JsonImporter
from tqdm import tqdm
from anytree.exporter import UniqueDotExporter, DotExporter
import heapq
import math
import time

### Global Settings

In [4]:
# Define sBert model which is used for all the comparing processes.
# It is a shared module-level instance: helper functions further down read `model_sbert` as a global.
model_sbert = SentenceTransformer('all-MiniLM-L6-v2')

# Create a custom logger (logging.getLogger returns the same object for the same name on every call)
logger = logging.getLogger("my_logger")

# Set the default log level
logger.setLevel(logging.DEBUG)

# Create a formatter for the log records
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Reuse an already attached file handler if there is one. Re-running this cell would
# otherwise add one more handler per run and every message would be written several times.
file_handler = next((h for h in logger.handlers if isinstance(h, logging.FileHandler)), None)
if file_handler is None:
    # Create a file handler to write logs to a file and add it to the logger
    file_handler = logging.FileHandler('app.log')
    logger.addHandler(file_handler)
file_handler.setFormatter(formatter)

# Ensure the logger does not propagate messages to the root logger
logger.propagate = False

### Gemini

In [5]:
! pip install -q -U google-generativeai


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [6]:
import pathlib
import textwrap
import os
import google.generativeai as genai

from IPython.display import display
from IPython.display import Markdown
# Safety settings passed to genai.GenerativeModel further down in this cell.
# Every listed harm category uses the "BLOCK_NONE" threshold.
safety_settings = [
    {
        "category": "HARM_CATEGORY_DANGEROUS",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE",
    }
]
# Name of the Gemini model that is instantiated below
gemini_model = "gemini-1.5-flash"

def to_markdown(text):
  """
  Wrap text into a Markdown display object rendered as a block quote.

  Bullet characters ('•') are turned into markdown list markers and every
  line (including blank ones) is prefixed with '> '.
  """
  with_list_markers = text.replace('•', '  *')
  quoted = textwrap.indent(with_list_markers, '> ', predicate=lambda _line: True)
  return Markdown(quoted)
# Read the API key from the environment so that a real key never has to be saved
# inside the notebook. The "Your_key" placeholder is only used when the
# GOOGLE_API_KEY environment variable is not set.
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY', "Your_key")
genai.configure(api_key=GOOGLE_API_KEY)
# Gemini model instance configured with the model name and safety settings defined above
model = genai.GenerativeModel(model_name=gemini_model, safety_settings=safety_settings)

class Response:
    """
    Minimal container that stores the given value in a ``text`` attribute.

    NOTE(review): presumably used as a stand-in for model responses, since
    extract_output reads ``response.text`` -- confirm against the callers.
    """
    def __init__(self, text):
        self.text = text

### Global functions

In [7]:
def get_new_parameters(num_samples, diversity, threshold_similarity_classes, threshold_skip_node_exploration, depth, decay_factor=0.1):
    """
    Derives the generation parameters to use at a given exploration depth.

    The two counts shrink with growing depth (never below 1), the two
    thresholds grow with depth (never above 1.0).

    Parameters:
    - num_samples (int): The initial number of samples to generate.
    - diversity (int): The initial diversity factor for generating subcategories.
    - threshold_similarity_classes (float): The initial threshold for class similarity.
    - threshold_skip_node_exploration (float): The initial threshold for skipping node exploration.
    - depth (int): The current depth of exploration.
    - decay_factor (float, optional): Strength of the per-depth adjustment. Defaults to 0.1.

    Returns:
    - tuple: (num_samples, diversity, threshold_similarity_classes, threshold_skip_node_exploration)
      after adjustment; the first two entries are rounded up to integers.
    """
    logger.info(f"Adjusting parameters for depth {depth} with decay factor {decay_factor}")

    # Counts decrease with depth, lower bound 1
    samples_at_depth = max(1, num_samples - (depth * decay_factor))
    diversity_at_depth = max(1, diversity - (depth * (decay_factor / 2.0)))

    # Thresholds increase with depth, upper bound 1.0
    class_threshold_at_depth = min(1.0, threshold_similarity_classes + (depth * (decay_factor / 30.0)))
    skip_threshold_at_depth = min(1.0, threshold_skip_node_exploration + (depth * (decay_factor / 50.0)))

    logger.debug(
        f"Adjusted num_samples: {samples_at_depth}, diversity: {diversity_at_depth}, "
        f"threshold_similarity_classes: {class_threshold_at_depth}, "
        f"threshold_skip_node_exploration: {skip_threshold_at_depth}"
    )

    # The counts are rounded up so that they stay integers
    adjusted = (
        math.ceil(samples_at_depth),
        math.ceil(diversity_at_depth),
        class_threshold_at_depth,
        skip_threshold_at_depth,
    )
    return adjusted

## Comparison


### LLM comparator

In [8]:
import sys
import os
# Put the local source directory on the import path.
# NOTE(review): presumably this is where the customised llm_comparator package imported
# in the next cell lives -- confirm that the relative path matches the working directory.
sys.path.append(os.path.abspath("./llm_comparator_custom/python/src"))

In [9]:

from llm_comparator import comparison
from llm_comparator import model_helper
from llm_comparator import llm_judge_runner
from llm_comparator import rationale_bullet_generator
from llm_comparator import rationale_cluster_generator
from llm_comparator import prompt_templates

In [10]:
def format_comparator_output(json_file):
    """
    Reduces a comparator result to the parts that are kept.

    Parameters:
    - json_file (dict): Result mapping that contains at least the keys
      'examples' and 'rationale_clusters'.

    Returns:
    - dict: {'llm_comparator': {'examples': ..., 'rationale_clusters': ...}}
    """
    kept_keys = ('examples', 'rationale_clusters')
    payload = {key: json_file[key] for key in kept_keys}
    return {'llm_comparator': payload}
def return_llm_comparator_data(sentence_old, sentence_new, topic_name):
    """
    Runs the llm_comparator pipeline (judge, bulletizer, clusterer) on one old/new pair.

    Parameters:
    - sentence_old: Old-period content; converted with str() and used as response A.
    - sentence_new: New-period content; converted with str() and used as response B.
    - topic_name (str): Topic that is inserted into the judge prompt.

    Returns:
    - tuple: (score of the first example in the result, result formatted by
      format_comparator_output).
    """
    judge_opts = {'num_repeats': 4}
    clusterer_opts = {'num_clusters': 4}

    # A single comparison input: both responses are judged against the same prompt
    llm_judge_inputs = [
        {
            'prompt': f'Tell me something about {topic_name}.',
            'response_a': str(sentence_old),
            'response_b': str(sentence_new),
        }
    ]

    generator = model_helper.VertexGenerationModelHelper(model_name="gemini-1.5-flash")
    embedder = model_helper.VertexEmbeddingModelHelper()
    judge = llm_judge_runner.LLMJudgeRunner(generator)

    bulletizer = rationale_bullet_generator.RationaleBulletGenerator(generator)
    clusterer = rationale_cluster_generator.RationaleClusterGenerator(
        generator, embedder
    )

    result = comparison.run(
        llm_judge_inputs,
        judge,
        bulletizer,
        clusterer,
        judge_opts=judge_opts,
        clusterer_opts=clusterer_opts
    )
    result = format_comparator_output(result)
    # Only one input was submitted, so the first example carries the score
    return result['llm_comparator']['examples'][0]['score'], result

### General

In [11]:
def format_data_to_json(input_text_category, old_data, new_data):
    """
    Merges the categorized sentences of the old and the new model into one nested dictionary.

    Args:
        input_text_category (str): The main category or topic; used as the single top-level key.
        old_data (dict): Maps category name -> list of sentences of the "old" model.
        new_data (dict): Maps category name -> list of sentences of the "new" model.

    Returns:
        dict: One entry per category found in either input. The similarity is
        initialised with 0.0 here.
              Structured like this:
              {
                  "Topic": {
                      "Category": {
                          "old": {"data": []},
                          "new": {"data": []},
                          "similarity": float
                      },
                      ...
                  }
              }
    """
    logger.info(f"Formatting data for category: {input_text_category}")

    categories = {}

    def _blank_entry():
        # Fresh lists for every category so that entries never share state
        return {"old": {"data": []}, "new": {"data": []}, "similarity": 0.0}

    # Old data is processed first, so its categories come first in the result
    for period, categorized in (("old", old_data), ("new", new_data)):
        for category, facts in categorized.items():
            entry = categories.setdefault(category, _blank_entry())
            entry[period]["data"].extend(facts)
            logger.debug(f"Added {len(facts)} facts to {period} category: {category}")

    logger.info("Data formatting complete")
    return {input_text_category: categories}

def encode_sentence(sentence):
    """Encodes `sentence` with the shared `model_sbert` model, requesting a tensor as result."""
    embedding = model_sbert.encode(sentence, convert_to_tensor=True)
    return embedding

def aggregate_classifications(json_data, root_topic_name, threshold):
    """
    Aggregates categories in json_data whose names are semantically similar, using the Sentence-BERT model.

    Categories are visited in dictionary order. Each not yet visited category becomes a
    group; every later category whose name similarity to it reaches `threshold` is merged
    into that group (its "old" and "new" data lists are appended). The root topic itself
    is never used as a group and never merged into one.

    Note: the entries of json_data are reused, not copied, so the data lists and the
    "similarity" value inside json_data are modified in place.

    Parameters:
    - json_data (dict): Maps category name -> {"old": {"data": [...]}, "new": {"data": [...]}, ...}.
    - root_topic_name (str): The root topic name to exclude from aggregation.
    - threshold (float): Minimum similarity (inclusive) for two categories to be merged.

    Returns:
    - dict: {root_topic_name: aggregated categories}, similarity of every group reset to 0.0.
    """
    logger.info("Starting aggregation of classifications")

    # Encode category names
    category_embeddings = {category: encode_sentence(category) for category in json_data}
    logger.debug(f"Encoded categories: {list(category_embeddings.keys())}")

    # Aggregate categories based on similarity
    aggregated_data = {}
    visited = set()

    for category1 in json_data:
        if category1 in visited or category1 == root_topic_name:
            continue

        group = json_data[category1]
        group["similarity"] = 0.0
        aggregated_data[category1] = group
        visited.add(category1)
        logger.info(f"Processing category: {category1}")

        for category2 in json_data:
            # category1 itself is already in `visited`. The root topic must be skipped
            # here as well, otherwise its data would be merged into another category.
            if category2 in visited or category2 == root_topic_name:
                continue

            similarity = model_sbert.similarity(category_embeddings[category1], category_embeddings[category2]).item()
            logger.debug(f"Similarity between {category1} and {category2}: {similarity}")

            if similarity >= threshold:
                logger.info(f"Aggregating {category2} into {category1} with similarity {similarity}")
                visited.add(category2)
                # Aggregate data
                group["old"]["data"].extend(json_data[category2]["old"]["data"])
                group["new"]["data"].extend(json_data[category2]["new"]["data"])

    logger.info("Aggregation complete")
    return {root_topic_name: aggregated_data}

def check_common_classifications(category_list, category, similarity_threshold=0.94):
    """
    Checks whether a very similar category was already generated in the tree, so that it can be
    skipped. This prevents the tree from regenerating classifications with facts.

    Parameters:
    - category_list (list): A list of all so far generated categories.
    - category (str): The category name to check.
    - similarity_threshold (float, optional): The category counts as common when its similarity
      to at least one listed category is strictly greater than this value. Defaults to 0.94.

    Returns:
    - bool: True if the category is common, False otherwise (also for an empty category_list).
    """
    # Nothing has been generated yet, so nothing can be similar; this also avoids
    # encoding and comparing against an empty batch.
    if len(category_list) == 0:
        logger.debug(f"Category {category} is common: False")
        return False

    category_embeddings_list = model_sbert.encode(category_list)
    category_embeddings_classification = model_sbert.encode(category)
    similarity = model_sbert.similarity(category_embeddings_classification, category_embeddings_list)
    result = (similarity > similarity_threshold).any().item()
    logger.debug(f"Category {category} is common: {result}")
    return result

def get_embeddings(list_old, list_new, model_sbert):
    """
    Encodes the old and the new sentences with the given model.

    Parameters:
    - list_old: Sentences (or text) of the old dataset.
    - list_new: Sentences (or text) of the new dataset.
    - model_sbert: Model whose encode() is called with convert_to_tensor=True.

    Returns:
    - list: [embeddings of list_old, embeddings of list_new]
    """
    logger.info("Encoding sentences into embeddings")
    # Old is encoded before new, same order as in the returned list
    return [model_sbert.encode(sentences, convert_to_tensor=True) for sentences in (list_old, list_new)]

def calculate_average_similarity(matching_sentences):
    """
    Computes the mean similarity score over matched sentence pairs.

    Parameters:
    - matching_sentences (list): Triples of (old sentence, new sentence, similarity score).

    Returns:
    - float: The mean of the scores, or 0 if there are no pairs.
    """
    scores = [score for (_old_sentence, _new_sentence, score) in matching_sentences]
    if scores:
        average_similarity = sum(scores) / len(scores)
    else:
        # No pairs at all: avoid the division by zero
        average_similarity = 0
    logger.info(f"Calculated average similarity: {average_similarity}")
    return average_similarity

def return_matching_sentences(list_old, list_new, model_sbert, similarity_type):
    """
    Returns matched sentences between two lists based on similarity scores.

    For 'paragraph_with_classification' both lists are joined into one paragraph each and
    a single pair is returned. For every other similarity type all old/new combinations
    are ranked by score (highest first) and a pair is selected whenever at least one of
    its two sentences has not been matched yet, so every sentence ends up in at least
    one pair.

    Parameters:
    - list_old: List of sentences from the old dataset.
    - list_new: List of sentences from the new dataset.
    - model_sbert: Sentence-BERT model used to compute embeddings and similarities.
    - similarity_type: 'paragraph_with_classification' for paragraph comparison; any other
      value (e.g. 'facts_with_classification') compares sentence by sentence.

    Returns:
    - selected_pairs: List of (old sentence, new sentence, score) tuples.
    - average_similarity: The average similarity score of the selected pairs
      (0 when one of the sentence lists is empty).
    """
    is_paragraph = similarity_type == 'paragraph_with_classification'

    if is_paragraph:
        # Treat the lists as joined paragraphs
        list_old = " ".join(list_old)
        list_new = " ".join(list_new)
    elif len(list_old) == 0 or len(list_new) == 0:
        # Nothing can be paired; same value calculate_average_similarity yields for no pairs
        return [], 0

    # Compute embeddings for both sides and their pairwise similarities
    embeddings_old, embeddings_new = get_embeddings(list_old, list_new, model_sbert)
    cosine_scores = model_sbert.similarity(embeddings_old, embeddings_new)

    # Only two paragraphs exist in paragraph mode, so they form the single pair
    if is_paragraph:
        paragraph_score = cosine_scores.item()
        return [(list_old, list_new, paragraph_score)], paragraph_score

    # Rank all (old, new) combinations by score. A position in the flattened (row-major)
    # score matrix maps back to its row/column via divmod with the number of columns.
    column_count = cosine_scores.size(1)
    flat_scores = cosine_scores.flatten()
    ranking = torch.argsort(flat_scores, descending=True)

    matched_sentences_old = set()
    matched_sentences_new = set()
    selected_pairs = []

    for position in ranking:
        flat_index = position.item()
        index_old, index_new = divmod(flat_index, column_count)
        # Skip combinations in which both sentences are already covered
        if index_old in matched_sentences_old and index_new in matched_sentences_new:
            continue
        selected_pairs.append((list_old[index_old], list_new[index_new], flat_scores[flat_index].item()))
        matched_sentences_old.add(index_old)
        matched_sentences_new.add(index_new)

    # Return selected pairs along with their average similarity
    return selected_pairs, calculate_average_similarity(selected_pairs)


def _generate_missing_sentences(node, start_year, end_year, generation_type, sample_count):
    """Generates `sample_count` sentences for the category chain of `node` in the given period."""
    # The prompts describe the input as a chain "RootCategory -> SubRootCategory -> ..."
    category_chain = " -> ".join(get_parent_node_names(node))
    # Only the first element of the returned triple (the generated sentences) is needed here
    predictions, _, _ = predict_with_gemini(
        input_text_category=category_chain,
        start_year=start_year,
        end_year=end_year,
        generation_type=generation_type,
        num_samples=sample_count
    )
    return predictions


def update_similarity_nodes(node_list, num_samples, year_old_start, year_old_end, year_new_start, year_new_end, similarity_type, improve_tree=0):
    """
    Updates similarity nodes by generating new sentences for underrepresented time periods
    and computing similarity scores between old and new sentences.

    Per node (unless improve_tree is set): if one period has fewer sentences than the other,
    the difference (at most num_samples) is generated for that period and appended to the
    node, then the similarity is recomputed. If both periods have the same, non-zero number
    of sentences, or improve_tree is set, only the similarity is recomputed. A node without
    any sentences gets the fixed score 5.0.

    Parameters:
    - node_list: List of nodes to update.
    - num_samples: Upper bound for the number of sentences generated per node.
    - year_old_start: Start year for the old period.
    - year_old_end: End year for the old period.
    - year_new_start: Start year for the new period.
    - year_new_end: End year for the new period.
    - similarity_type: Type of similarity computation ('facts_with_classification', 'paragraph_with_classification', 'llm_comparator').
    - improve_tree: Flag indicating whether to improve the tree or not (this is used to make changes or adapt the tree after generation) (default is 0 for no).

    Returns:
    - None. The nodes are updated in place with new sentences and similarity scores.
    """

    # Facts are generated for fact comparison, paragraphs for every other similarity type
    generation_type = 'facts' if similarity_type == 'facts_with_classification' else 'paragraph'

    with tqdm(total=len(node_list)) as sBar:
        for node in node_list:
            sentences_old = node.data.sentences_old
            sentences_new = node.data.sentences_new
            count_old = len(sentences_old)
            count_new = len(sentences_new)

            # Old sentences are fewer than new sentences: fill up the old period
            if count_old < count_new and not improve_tree:
                sentences_old.extend(_generate_missing_sentences(
                    node, year_old_start, year_old_end, generation_type,
                    min(count_new - count_old, num_samples)
                ))
                _update_node_similarity(node, similarity_type)

            # New sentences are fewer than old sentences: fill up the new period.
            # Chain separator and generation type are the same as for the old period.
            elif count_new < count_old and not improve_tree:
                sentences_new.extend(_generate_missing_sentences(
                    node, year_new_start, year_new_end, generation_type,
                    min(count_old - count_new, num_samples)
                ))
                _update_node_similarity(node, similarity_type)

            # Sentence counts are equal (and non-zero) or improve_tree is enabled
            elif (count_old == count_new and sentences_old and sentences_new) or improve_tree:
                _update_node_similarity(node, similarity_type)

            # Remaining case (no sentences on either side): fixed default score
            else:
                update_node_data(node, new_similarity_score=5.0)

            sBar.update(1)

def _update_node_similarity(node, similarity_type):
    """
    Helper function to update the node's similarity score based on the specified similarity type.
    
    Parameters:
    - node: The node whose similarity score is being updated.
    - similarity_type: The type of similarity to compute ('llm_comparator' or others).
    """
    
    if similarity_type != 'llm_comparator':
        most_matching_sentence_list, avg_score = return_matching_sentences(
            node.data.sentences_old, 
            node.data.sentences_new, 
            model_sbert, 
            similarity_type
        )
        update_node_data(node, new_matching_sentences=most_matching_sentence_list, new_similarity_score=avg_score * 1.1)
    else:
        score, llm_comparator_data = return_llm_comparator_data(
            node.data.sentences_old, 
            node.data.sentences_new, 
            node.data.category_name
        )
        update_node_data(node, new_similarity_score=score * 1.1, new_llm_comparator=llm_comparator_data)


## Model Predictions logic

In [12]:
def create_prompt(generation_type, tokens_per_fact, num_samples, input_question, start_year, end_year, subcategory_count=None):
    """
    Generates system and user prompts for an AI assistant to create historical facts or text based on specified criteria.

    Parameters:
    - generation_type (str): One of
        'facts'                         - facts for the last category of the chain,
        'facts_with_classification'     - subcategories with facts for each of them,
        'paragraph_with_classification' - subcategories with one text paragraph each,
        'paragraph'                     - one text for the last category of the chain.
    - tokens_per_fact (int): The maximum number of tokens allowed per fact. Only used by the two
      '..._with_classification' types (for paragraphs multiplied by num_samples).
    - num_samples (int): The number of facts to generate.
    - input_question (str): The input hierarchical chain of categories or a keyword.
    - start_year (int): The starting year for the contextual consistency of the output.
    - end_year (int): The ending year for the contextual consistency of the output.
    - subcategory_count (int, optional): The number of different subcategories to generate. Only used
      by the two '..._with_classification' types.

    Returns:
    - tuple: A tuple containing the system prompt and user prompt strings.
      None is returned for an unknown generation_type.
    """
    if generation_type == 'facts':
        system_prompt = f'''
        You are a fact-generating assistant specializing in creating historical facts for different categories.
        Task Description:
        The input is a hierarchical chain of categories starting from a general category and moving towards more specific ones. 
        Your task is to generate interesting historical facts for the last category in the provided chain.

        Requirements:
        1. Input Structure:
            - A hierarchical chain of categories (RootCategory -> SubRootCategory -> SubSubRootCategory -> ...).

        2. Output Structure:
            - Generate {num_samples} unique and interesting historical facts for the last category in the chain.
        
        3. Contextual Consistency:
            - Generate facts as if you were living between the years {start_year} and {end_year}.
            - Ensure that the facts align with the knowledge, culture, and technological advancements of that time period.
            - Avoid using information or events that occurred after {end_year}.
            - Please provide a formal, fact-based response with a professional tone, avoiding casual language and ensuring accuracy in the information presented.
        
        Output Format:
        1. <fact text>
        2. <fact text>
        ...
        {num_samples}. <fact text>
        '''

    elif generation_type == 'facts_with_classification':
        system_prompt = f'''
        You are a fact-generating assistant specializing in creating historical facts for different categories.

        Task Description:
        The input can be just a keyword or a hierarchical chain of categories starting from a general category and moving towards more specific ones. 
        Your task is to generate a number of up to {subcategory_count} different categories and {num_samples} corresponding different facts. 
        For a provided chain or keyword, generate up to {subcategory_count} different subcategories at the end of the chain or for the keyword and provide {num_samples} interesting 
        historical facts for each of these subcategories.

        Requirements:
        1. Input Structure:
            - A hierarchical chain of categories (RootCategory -> SubRootCategory -> SubSubRootCategory -> ... ).

        2. Output description:
            - For a chain or keyword, generate up to {subcategory_count} subcategories branching from the last category in the chain or for the keyword.
            - For each subcategory, generate {num_samples} unique and interesting historical facts which each should be at maximum {tokens_per_fact} tokens long.

        3. Contextual Consistency:
            - Generate facts as if you were living between the years {start_year} and {end_year}.
            - Ensure that the facts align with the knowledge, culture, and technological advancements of that time period.
            - Avoid using information or events that occurred after {end_year}.
            - Please provide a formal, fact-based response with a professional tone, avoiding casual language and ensuring accuracy in the information presented.

        4. Output Format:

        1. Category: <Generated Category 1>
        - <1. Fact text>
        - <2. Fact text>
        ...
        - <{num_samples}. Fact text>

        2. Category: <Generated Category 2>
        - <Fact text 1>
        - <Fact text 2>
        ...
        - <Fact text {num_samples}>

        ...    

        {subcategory_count}. Category: <Generated Category {subcategory_count}>
        - <Fact text 1>
        - <Fact text 2>
        ...
        - <Fact text {num_samples}>
        '''

    elif generation_type == 'paragraph_with_classification':
        system_prompt = f'''
        You are a text generating assistant specializing in creating historical text for different categories.

        Task Description:
        The input can be either a single keyword or a hierarchical chain of categories, 
        starting from a general category and progressing to more specific ones. Your task is to generate up to {subcategory_count} distinct 
        categories and write a corresponding text paragraph for each. Based on the provided chain or keyword, generate up to {subcategory_count} 
        subcategories at the end of the chain or for the keyword, and provide an engaging and precise description for each subcategory, 
        offering a concise summary.

        Requirements:
        1. Input Structure:
            - A hierarchical chain of categories (RootCategory -> SubRootCategory -> SubSubRootCategory -> ...).

        2. Output description:
            - For a chain or keyword, generate up to {subcategory_count} subcategories branching from the last category in the chain or for the keyword.
            - For each subcategory, generate a unique, diverse and interesting summary which should be about {tokens_per_fact * num_samples} tokens long.

        3. Contextual Consistency:
            - Generate text as if you were living between the years {start_year} and {end_year}.
            - Ensure that the text aligns with the knowledge, culture, and technological advancements of that time period.
            - Avoid using information or events that occurred after {end_year}.
            - Please provide a formal, fact-based response with a professional tone, avoiding casual language and ensuring accuracy in the information presented.

        4. Output Format:
        
        1. Category: <Generated Category 1>
        Text: <generated text 1>

        2. Category: <Generated Category 2>
        Text: <generated text 2>

        ...    

        {subcategory_count}. Category: <Generated Category {subcategory_count}>
        Text: <generated text {subcategory_count}>
        '''

    elif generation_type == 'paragraph':
        system_prompt = f'''
        You are a text assistant specializing in creating text for different categories.
        Task Description:
        The input is a hierarchical chain of categories starting from a general category and moving towards more specific ones. 
        Your task is to generate interesting text for the last category in the provided chain.

        Requirements:
        1. Input Structure:
            - A hierarchical chain of categories (RootCategory -> SubRootCategory -> SubSubRootCategory -> ...).

        2. Output Structure:
            - Generate an interesting unique text for the last category in the chain.
        
        3. Contextual Consistency:
            - Generate the text as if you were living between the years {start_year} and {end_year}.
            - Ensure that the text aligns with the knowledge, culture, and technological advancements of that time period.
            - Avoid using information or events that occurred after {end_year}.
            - Please provide a formal, fact-based response with a professional tone, avoiding casual language and ensuring accuracy in the information presented.
        
        Output Format:
        Text: <generated Text>
        '''

    else:
        # Unknown generation types produce no prompts (unchanged behaviour, now explicit)
        return None

    # The user prompt is the same for every generation type
    user_prompt = f"Input: {input_question}\n\nOutput:"
    return system_prompt, user_prompt

def extract_output(generation_type, response, num_samples, question):
    """
    Parses the generated text of a response according to the generation type.

    Parameters:
    - generation_type (str): 'facts', 'paragraph', 'facts_with_classification', or 'paragraph_with_classification'.
    - response (object): The response object; its `text` attribute holds the generated text.
    - num_samples (int): The number of facts to extract (used for 'facts' only).
    - question (str): The input question or hierarchical chain of categories
      (used for 'facts_with_classification' only).

    Returns:
    - tuple: (extracted data, response, second_round). The extracted data is a list for
      'facts' and 'paragraph' (empty for an unknown type) and a dict of categories for the
      '..._with_classification' types; only those two can report that a second round is needed.
    """
    output_text = response.text
    logger.info(f"Extracting output for generation type: {generation_type}")
    logger.debug(f"Output text: {output_text}")

    # Categorized variants deliver their own second-round flag
    if generation_type == 'facts_with_classification':
        categorized, needs_second_round = _extract_facts_with_classification(output_text, question)
        return categorized, response, needs_second_round

    if generation_type == 'paragraph_with_classification':
        categorized, needs_second_round = _extract_paragraphs_with_classification(output_text)
        return categorized, response, needs_second_round

    # Plain variants never request a second round
    if generation_type == 'facts':
        extracted = _extract_facts_from_lines(output_text, num_samples)
    elif generation_type == 'paragraph':
        extracted = _extract_paragraphs(output_text)
    else:
        extracted = []

    return extracted, response, False


def _extract_facts_from_lines(output_text, num_samples):
    """
    Extracts facts from the text based on numbered bullet points.
    
    Parameters:
    - output_text (str): The generated text.
    - num_samples (int): Number of samples to extract.

    Returns:
    - list: A list of extracted facts.
    """
    facts = []
    lines = output_text.split('\n')
    for line in lines:
        line = line.replace('**', '')
        if any(line.startswith(f"{i + 1}.") for i in range(num_samples)):
            fact = line.split('.', 1)[1].strip()
            facts.append(fact)
    logger.info(f"Extracted {len(facts)} facts")
    return facts


def _extract_paragraphs(output_text):
    """
    Extracts paragraphs from the text.

    Parameters:
    - output_text (str): The generated text.

    Returns:
    - list: A list of extracted paragraphs.
    """
    facts = []
    lines = output_text.split('\n')
    for line in lines:
        line = line.replace('**', '')
        if re.match(r'.*Text:(.+) | .*1.(.+)', line):
            fact_text = re.sub(r'.*Text:| *1.', '', line).strip()
            facts.append(fact_text)
            logger.info(f"Extracted fact: {fact_text}")
    logger.info(f"Extracted {len(facts)} paragraphs")
    return facts


def _extract_facts_with_classification(output_text, question):
    """
    Extracts facts categorized by category or subcategory.
    
    Parameters:
    - output_text (str): The generated text.
    - question (str): The input question or hierarchical chain of categories.

    Returns:
    - tuple: A tuple containing the categorized facts and a boolean indicating if a second round is needed.
    """
    categorized_facts = {}
    current_category = None
    patterns_cat = r'.*Category:|.*Subcategory:'
    patterns_fact = r'- Fact \d+:|\d+\.|\s*-'
    second_round = False

    for line in output_text.split('\n'):
        line = line.replace('**', '')
        if re.match(patterns_cat, line):
            current_category = re.sub(patterns_cat, '', line).strip()
            categorized_facts[current_category] = []
            logger.info(f"New category detected: {current_category}")
        elif current_category and re.match(patterns_fact, line):
            fact = re.sub(patterns_fact, '', line).strip()
            categorized_facts[current_category].append(fact)

    logger.info(f"Extracted categories: {list(categorized_facts.keys())}")
    if len(categorized_facts) == 1:
        second_round = _check_second_round(categorized_facts, question)
    
    return categorized_facts, second_round


def _extract_paragraphs_with_classification(output_text):
    """
    Extracts paragraphs categorized by category or subcategory.

    Parameters:
    - output_text (str): The generated text.

    Returns:
    - tuple: A tuple containing the categorized facts and a boolean indicating if a second round is needed.
    """
    categorized_facts = {}
    current_category = None
    patterns_cat = r'.*Category:|.*Subcategory:'
    patterns_fact = r'Text:|\d+\.|.*Text:'
    second_round = False

    for line in output_text.split('\n'):
        line = line.replace('**', '')
        if re.match(patterns_cat, line):
            current_category = re.sub(patterns_cat, '', line).strip()
            categorized_facts[current_category] = []
            logger.info(f"New category detected: {current_category}")
        elif current_category and re.match(patterns_fact, line):
            fact = re.sub(patterns_fact, '', line).strip()
            categorized_facts[current_category].append(fact)

    return categorized_facts, second_round


def _check_second_round(categorized_facts, question):
    """
    Checks if a second round of fact extraction is needed.

    Parameters:
    - categorized_facts (dict): Dictionary of categorized facts.
    - question (str): The input question or hierarchical chain of categories.

    Returns:
    - bool: True if a second round is needed, False otherwise.
    """
    generated_category = list(categorized_facts.keys())[0]
    if len(generated_category) > 120:
        raise ValueError("Category extraction didn't work!")
    
    last_arrow_index = question.rfind('->')
    if last_arrow_index != -1:
        result = question[last_arrow_index + 3:].strip()
        if result in generated_category:
            logger.info("Detected need for a second round of generation")
            return True
    return False


def predict_with_gemini(
    input_text_category: str,
    start_year: str,
    end_year: str,
    generation_type: str,
    second_round: bool = False,
    previous_output: str = "",
    num_samples: int = 3,
    temperature: float = 0.7,
    tokens_per_fact: int = 50,
    logprobs: bool = False,
    subcategory_count: int = 3
):
    """
    Generates predictions based on a specified time range and input text category using the Gemini model.

    The request is retried with exponential backoff (5 s, 10 s, 20 s, ...) on
    any exception raised while generating or parsing the response.

    Parameters:
    - input_text_category (str): The category of input text.
    - start_year (str): The starting year for the prediction.
    - end_year (str): The ending year for the prediction.
    - generation_type (str): The type of generation (e.g., 'facts', 'paragraph').
    - second_round (bool): Whether this is a second-round generation to improve subcategories.
    - previous_output (str): The output from the first round, used in the second round.
    - num_samples (int): The number of facts to generate.
    - temperature (float): The sampling temperature. Currently not forwarded to the model.
    - tokens_per_fact (int): The number of tokens to allocate per fact.
    - logprobs (bool): Whether to log probabilities of the output. Currently unused.
    - subcategory_count (int): The number of subcategories to generate.

    Returns:
    - The extracted output from the model's response, as returned by
      `_process_response` (callers in this notebook unpack it into three values).

    Raises:
    - RuntimeError: If every attempt failed; the last underlying exception is
      attached as the cause.
    """
    max_retries = 10  # Maximum number of attempts
    wait_time = 5     # Initial wait time in seconds, doubled after each failure

    # Create the system and user prompts
    system_prompt, user_prompt = create_prompt(
        generation_type, tokens_per_fact, num_samples, input_text_category, start_year, end_year, subcategory_count
    )

    max_tokens = tokens_per_fact * num_samples * subcategory_count

    # Prepare messages for first or second round of generation
    messages = _prepare_generation_messages(system_prompt, user_prompt, second_round, previous_output)

    logger.info(f"{'Second' if second_round else 'First'} round of generation")
    logger.info(f"Asking question for {start_year} to {end_year}: {user_prompt}")

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            response = _generate_content_with_gemini(messages, max_tokens)
            return _process_response(response, generation_type, num_samples, input_text_category)
        except Exception as e:
            last_error = e
            if attempt == max_retries:
                # No point in sleeping when there is no further attempt.
                logger.error(f"Attempt {attempt} of {max_retries} failed. Giving up. {e}")
                break
            logger.warning(f"Attempt {attempt} of {max_retries} failed. Retrying in {wait_time} seconds... {e}")
            time.sleep(wait_time)
            wait_time *= 2  # Exponential backoff

    # The original fell through to `return extracted_output`, a name that is
    # never defined, which surfaced as an unrelated NameError.
    raise RuntimeError(
        f"Gemini generation failed after {max_retries} attempts for '{input_text_category}' "
        f"({start_year} to {end_year})"
    ) from last_error


def _prepare_generation_messages(system_prompt, user_prompt, second_round, previous_output):
    """
    Prepares the messages for the Gemini model based on the round of generation.
    
    Parameters:
    - system_prompt (str): The system prompt.
    - user_prompt (str): The user prompt.
    - second_round (bool): Whether this is a second-round generation.
    - previous_output (str): The output from the first round.

    Returns:
    - list: A list of message dictionaries.
    """
    if second_round:
        second_round_question = (
            "You didn't generate different subcategories! You just used the last item of the chain "
            "and generated facts to it. I want subcategories like described in the task before. Try again!"
        )
        return [
            {"role": "model", "parts": [system_prompt]},
            {"role": "user", "parts": [user_prompt]},
            {"role": "model", "parts": [previous_output]},
            {"role": "user", "parts": [second_round_question]}
        ]
    
    return [
        {"role": "model", "parts": [system_prompt]},
        {"role": "user", "parts": [user_prompt]}
    ]


def _generate_content_with_gemini(messages, max_tokens):
    """
    Issues one generation request to the Gemini model.

    Uses the module-level `model` and `genai` objects, which are defined
    elsewhere in the notebook. Exactly one candidate is requested, with
    nucleus sampling fixed at top_p=0.1.

    Parameters:
    - messages (list): The messages to send to the model.
    - max_tokens (int): Maximum tokens for the output.

    Returns:
    - response (object): The model's response.
    """
    generation_settings = genai.types.GenerationConfig(
        candidate_count=1,
        max_output_tokens=max_tokens,
        top_p=0.1,
    )
    return model.generate_content(messages, generation_config=generation_settings)


def _process_response(response, generation_type, num_samples, input_text_category):
    """
    Turns a raw model response into extracted output via `extract_output`.

    If the first candidate's finish reason equals 4, the response is replaced
    by a placeholder `Response` carrying a "no information" message before
    extraction.

    Parameters:
    - response (object): The model's response.
    - generation_type (str): The type of generation (e.g., 'facts', 'paragraph').
    - num_samples (int): The number of samples to extract.
    - input_text_category (str): The input text category.

    Returns:
    - The extracted output.
    """
    finish_reason = response.candidates[0].finish_reason
    logger.debug(f"Generated output: \n{response} with finish reason: {finish_reason}")

    # NOTE(review): 4 presumably marks a candidate without usable text in the
    # Gemini API (e.g. blocked output) - confirm against the FinishReason enum.
    if finish_reason == 4:
        logger.warning(f"No Candidate found for question.")
        response = Response('No Information found about this topic!')

    return extract_output(generation_type, response, num_samples, input_text_category)

#global_ml_counter = 0

# def _handle_rate_limiting():
#     """
#     Handles rate limiting by sleeping if necessary based on global ML counter and elapsed time.
#     """
#     global global_ml_counter, start_time

#     if global_ml_counter % 15 == 0:
#         end_time = time.time()
#         elapsed_time = end_time - start_time
#         minutes_elapsed = elapsed_time / 60

#         if minutes_elapsed < (global_ml_counter / 15):
#             logger.info("Rate limit hit, sleeping...")
#             time.sleep((60 * math.ceil(minutes_elapsed)) - elapsed_time)



## Logic for building up the classification tree

In [13]:
class NodeData:
    """
    Payload attached to every tree node.

    Holds the category name, the sentences generated for the new and the old
    time window, the similarity score, and optionally the matching sentences
    and the LLM comparator result. The two optional fields fall back to a
    fresh empty list when omitted or falsy.
    """

    def __init__(self, category_name, sentences_new, sentences_old, similarity_score, matching_sentences=None, llm_comparator=None):
        self.category_name = category_name
        self.sentences_new = sentences_new
        self.sentences_old = sentences_old
        self.similarity_score = similarity_score
        self.matching_sentences = matching_sentences if matching_sentences else []
        self.llm_comparator = llm_comparator if llm_comparator else []

    def __repr__(self):
        field_names = (
            "category_name",
            "sentences_new",
            "sentences_old",
            "similarity_score",
            "matching_sentences",
            "llm_comparator",
        )
        rendered = ", ".join(f"{field}={getattr(self, field)}" for field in field_names)
        return f"NodeData({rendered})"
    
def serialize_mapper(node):
    """
    Flattens a tree node into a plain dict.

    The result contains the node's name followed by the six `NodeData` fields
    read from `node.data`.
    """
    serialized = {"name": node.name}
    for field in (
        "category_name",
        "sentences_new",
        "sentences_old",
        "similarity_score",
        "matching_sentences",
        "llm_comparator",
    ):
        serialized[field] = getattr(node.data, field)
    return serialized


def deserialize_mapper(data):
    """
    Rebuilds a parentless tree node from a flat dict.

    `category_name`, `sentences_new`, `sentences_old` and `similarity_score`
    are required keys; `matching_sentences` and `llm_comparator` default to an
    empty list when missing.
    """
    required = {
        key: data[key]
        for key in ("category_name", "sentences_new", "sentences_old", "similarity_score")
    }
    payload = NodeData(
        **required,
        matching_sentences=data.get("matching_sentences", []),
        llm_comparator=data.get("llm_comparator", []),
    )
    return Node(data["name"], data=payload)

class CustomJsonExporter(JsonExporter):
    """
    JSON exporter that writes the tree as nested dicts with the keys
    "name", "data" (the `NodeData` fields, or null) and "children".
    """

    def __init__(self):
        super().__init__()

    def write(self, node, file):
        """
        Serialises `node` and all of its descendants to the open text file `file`.

        Nodes without a `data` attribute, or whose `data` is None (as produced
        by `CustomJsonImporter` for entries with a null payload), are written
        with "data": null. Checking only `hasattr` would crash on such nodes
        when a loaded tree is exported again.
        """
        def node_to_dict(node):
            payload = getattr(node, 'data', None)
            return {
                "name": node.name,
                "data": {
                    "category_name": payload.category_name,
                    "sentences_new": payload.sentences_new,
                    "sentences_old": payload.sentences_old,
                    "similarity_score": payload.similarity_score,
                    "matching_sentences": payload.matching_sentences,
                    "llm_comparator": payload.llm_comparator
                } if payload is not None else None,
                "children": [node_to_dict(child) for child in node.children]
            }

        json.dump(node_to_dict(node), file, indent=4)

class CustomJsonImporter(JsonImporter):
    """
    JSON importer for the nested "name" / "data" / "children" layout.
    """

    def __init__(self):
        super().__init__()

    def read(self, file):
        """
        Parses the JSON document in the open file `file` and returns the root node.

        Entries whose "data" is null become nodes with `data=None`; otherwise a
        `NodeData` is built, with `matching_sentences` and `llm_comparator`
        defaulting to an empty list when absent.
        """
        def build(entry, parent=None):
            raw = entry["data"]
            payload = None
            if raw is not None:
                payload = NodeData(
                    category_name=raw["category_name"],
                    sentences_new=raw["sentences_new"],
                    sentences_old=raw["sentences_old"],
                    similarity_score=raw["similarity_score"],
                    matching_sentences=raw.get("matching_sentences", []),
                    llm_comparator=raw.get("llm_comparator", []),
                )
            current = Node(entry["name"], parent=parent, data=payload)
            # Children attach themselves to `current` through the parent argument.
            for child_entry in entry["children"]:
                build(child_entry, parent=current)
            return current

        return build(json.load(file))

def add_entries_to_tree(parent_node, category_name, sentences_old, sentences_new, similarity_score):
    """
    Attaches a new child named `category_name` to `parent_node`.

    The child carries a `NodeData` built from the given sentences and
    similarity score. Nothing is returned; the node is linked to the tree
    through its parent.
    """
    Node(
        category_name,
        parent=parent_node,
        data=NodeData(
            category_name=category_name,
            sentences_old=sentences_old,
            sentences_new=sentences_new,
            similarity_score=similarity_score,
        ),
    )

def add_tree_category_layer(parent_node, json_data, topic, category_list):
    """
    Adds one child node per category found under `json_data[topic]`.

    A category is skipped when `category_list` is non-empty and
    `check_common_classifications` reports it as already covered. Every added
    category is appended to `category_list`, which is therefore mutated in
    place.
    """
    for category in json_data[topic]:
        logger.info(f"Adding category layer for topic {topic}")

        already_known = category_list and check_common_classifications(category_list, category)
        if already_known:
            logger.info(f"Skipped common classification {category} in the tree.")
            continue

        entry = json_data[topic][category]
        add_entries_to_tree(
            parent_node,
            category,
            entry["old"]["data"],
            entry["new"]["data"],
            entry["similarity"],
        )
        category_list.append(category)
        logger.debug(f"Added category {category} to tree")


def update_node_data(node, new_sentences_new=None, new_sentences_old=None, new_similarity_score=None, new_matching_sentences=None, new_llm_comparator = None):
    """
    Updates the payload of `node` in place; arguments left as None are ignored.

    The two sentence lists are extended with the given items, whereas the
    similarity score, the matching sentences and the LLM comparator result
    are replaced.
    """
    extended_fields = (
        ("sentences_new", new_sentences_new),
        ("sentences_old", new_sentences_old),
    )
    replaced_fields = (
        ("similarity_score", new_similarity_score),
        ("matching_sentences", new_matching_sentences),
        ("llm_comparator", new_llm_comparator),
    )

    for field, extra_items in extended_fields:
        if extra_items is not None:
            getattr(node.data, field).extend(extra_items)

    for field, value in replaced_fields:
        if value is not None:
            setattr(node.data, field, value)

### evaluation methods for tree

In [14]:
def save_tree_json(tree, filename):
    """
    Writes the tree rooted at `tree` to `filename` as JSON.

    Parameters:
    - tree (Node): The root node of the tree.
    - filename (str): The filename to save the tree to; an existing file is overwritten.
    """
    logger.info(f"Saving tree to {filename}")
    json_writer = CustomJsonExporter()
    with open(filename, "w") as handle:
        json_writer.write(tree, handle)

def load_tree_json(filename):
    """
    Reads a tree that was previously stored as JSON.

    Parameters:
    - filename (str): The filename to load the tree from.

    Returns:
    - Node: The root node of the tree.
    """
    logger.info(f"Loading tree from {filename}")
    json_reader = CustomJsonImporter()
    with open(filename, "r") as handle:
        return json_reader.read(handle)

def print_tree(tree=None, filename=None):
    """
    Generates a graphical representation of the tree and saves it as a PNG file.

    When `filename` is given, the tree is loaded from that JSON file and the
    picture is written to "tree.png". Otherwise the given `tree` is rendered
    to "<root name>.png".

    Parameters:
    - tree (Node, optional): The root node of the tree. Required if filename is not provided.
    - filename (str, optional): The filename to load the tree from. Required if tree is not provided.

    Raises:
    - ValueError: If neither `tree` nor `filename` is provided.
    """
    logger.info("Printing tree")

    def nodeattrfunc(node):
        label = node.name
        if hasattr(node, 'data') and node.data:
            similarity_score = '-' if node.data.similarity_score is None else f"{node.data.similarity_score:.3f}"
            label += f"\nScore: {similarity_score}"
        # The label is emitted inside a double-quoted DOT string, so embedded
        # quotes have to be escaped to keep the generated graph file valid.
        escaped_label = str(label).replace('"', '\\"')
        return f'label="{escaped_label}"'

    if filename:
        logger.info(f"Loading tree from file: {filename}")
        loaded_tree = load_tree_json(filename)
        UniqueDotExporter(loaded_tree, nodeattrfunc=nodeattrfunc).to_picture("tree.png")
    elif tree is not None:
        logger.info(f"Generating tree picture for: {tree.name}")
        UniqueDotExporter(tree, nodeattrfunc=nodeattrfunc).to_picture(f"{tree.name}.png")
    else:
        raise ValueError("print_tree requires either a tree or a filename")

def find_lowest_similarity_nodes(root, top_n=10):
    """
    Finds the nodes with the lowest similarity scores in the tree.

    Nodes without a payload or without a similarity score are ignored. Nodes
    with equal scores are returned in pre-order (parent before children,
    siblings in their stored order). Selection is done with a key function, so
    nodes themselves are never compared; pushing `(score, node)` tuples onto a
    heap raised a TypeError whenever two scores were equal.

    Parameters:
    - root (Node): The root node of the tree.
    - top_n (int, optional): The number of nodes with the lowest similarity scores to retrieve. Defaults to 10.

    Returns:
    - list: A list of nodes with the lowest similarity scores, lowest first.
    """
    logger.info(f"Finding the top {top_n} nodes with the lowest similarity scores")

    # Iterative pre-order traversal collecting every node that has a score.
    scored_nodes = []
    pending = [root]
    while pending:
        node = pending.pop()
        payload = getattr(node, 'data', None)
        if payload is not None and payload.similarity_score is not None:
            scored_nodes.append(node)
        # Reversed so that the first child is popped (visited) first.
        pending.extend(reversed(node.children))

    lowest_similarity_nodes = heapq.nsmallest(
        top_n, scored_nodes, key=lambda scored: scored.data.similarity_score
    )

    logger.debug(f"Lowest similarity nodes: {[node.name for node in lowest_similarity_nodes]}")
    return lowest_similarity_nodes

### tree operations

In [15]:
def find_first_node_by_category_name(tree, category_name):
    """
    Returns the first node (in anytree's search order) whose name equals `category_name`.

    The earlier call `find_by_attr(tree, name=category_name)` passed the
    category as the *attribute name* and omitted the required `value`
    argument, so it could not succeed. `findall` is used instead so that
    several nodes sharing a name yield the first one rather than an error.

    Parameters:
    - tree (Node): The root node of the tree to search.
    - category_name (str): The node name to look for.

    Returns:
    - Node or None: The first matching node, or None if no node has that name.
    """
    logger.info(f"Finding first node by category name: {category_name}")
    matches = findall(tree, filter_=lambda candidate: candidate.name == category_name)
    return matches[0] if matches else None

def get_parent_node_names(node):
    """
    Builds the chain of node names leading from the root down to `node`.

    Parameters:
    - node (Node): The node from which to start walking up the tree.

    Returns:
    - list: Node names ordered from the root to the given node (inclusive).
    """
    logger.info(f"Getting parent node names for node: {node.name}")

    names_from_root = []
    cursor = node
    while cursor is not None:
        # Walking upwards, so each ancestor goes in front of what was collected.
        names_from_root.insert(0, cursor.name)
        cursor = cursor.parent

    logger.debug(f"Parent names: {names_from_root}")
    return names_from_root

def get_leaf_nodes(root, current_depth):
    """
    Collects the leaves that sit exactly one level below `current_depth`.

    Parameters:
    - root (Node): The root node of the tree.
    - current_depth (int): The depth of the layer processed so far; leaves are
      taken from depth `current_depth + 1`.

    Returns:
    - list: The leaf nodes among the root's descendants at that depth.
    """
    target_depth = current_depth + 1
    logger.info(f"Getting leaf nodes at depth: {target_depth}")

    leaf_nodes = []
    for candidate in root.descendants:
        if candidate.is_leaf and candidate.depth == target_depth:
            leaf_nodes.append(candidate)

    logger.debug(f"Found {len(leaf_nodes)} leaf nodes at depth {target_depth}")
    return leaf_nodes

## Main Loop of algorithm

In [27]:
# Counter of model calls, intended for rate limiting. The helper that would
# read it (`_handle_rate_limiting`) is currently commented out, so resetting
# it here has no effect on the run.
global_ml_counter = 0

In [28]:
# Settings for the prediction
year_old_start, year_old_end = "1990", "1995"
year_new_start, year_new_end = "2018", "2023"
generation_type = 'paragraph_with_classification'  # Can be 'paragraph_with_classification' or 'facts_with_classification'
similarity_type = 'llm_comparator'
num_samples = 10
subcategory_count = 3
tokens_per_fact = 50
name_for_json_output_file = "test.json"  # name of the JSON file that receives the resulting tree
iteration_depth_for_tree = 2
root_topic_name = 'public health crisis'  # the start topic name
threshold_similarity_classes = 0.93
threshold_skip_node_exploration = 0.75

# Initialize lists
node_list = []
category_list = []

# Start time for performance tracking
start_time = time.time()


def _predict_for_period(question, start_year, end_year, retry_on_flat_output):
    """
    Queries Gemini for one time window using the current global settings.

    `num_samples`, `subcategory_count`, `tokens_per_fact` and
    `generation_type` are read from the module-level settings at call time, so
    values updated between depth levels are honoured. Previously the deeper
    levels omitted `subcategory_count` and `tokens_per_fact` and silently fell
    back to the defaults of `predict_with_gemini`.

    When `retry_on_flat_output` is true and the first answer is flagged as
    needing a second round, the model is asked again with its previous answer
    as context.

    Returns the `(predictions, response, needs_second_round)` tuple of the
    last call that was made.
    """
    predictions, response, needs_second_round = predict_with_gemini(
        input_text_category=question,
        start_year=start_year,
        end_year=end_year,
        generation_type=generation_type,
        num_samples=num_samples,
        subcategory_count=subcategory_count,
        tokens_per_fact=tokens_per_fact,
    )
    if retry_on_flat_output and needs_second_round:
        predictions, response, needs_second_round = predict_with_gemini(
            input_text_category=question,
            start_year=start_year,
            end_year=end_year,
            generation_type=generation_type,
            second_round=True,
            previous_output=response.text,
            num_samples=num_samples,
            subcategory_count=subcategory_count,
            tokens_per_fact=tokens_per_fact,
        )
    return predictions, response, needs_second_round


for depth in range(iteration_depth_for_tree):
    if depth == 0:
        # Initial generation for the root topic (no second round at this level)
        generated_question = root_topic_name

        predictions_old, response_old, check_output = _predict_for_period(
            generated_question, year_old_start, year_old_end, retry_on_flat_output=False
        )
        predictions_new, response_new, check_output = _predict_for_period(
            generated_question, year_new_start, year_new_end, retry_on_flat_output=False
        )

        # Format and aggregate data
        json_data = format_data_to_json(generated_question, predictions_old, predictions_new)
        aggregate_class_json = aggregate_classifications(json_data[root_topic_name], root_topic_name, threshold_similarity_classes)

        # Create tree root
        node_data = NodeData(
            category_name=root_topic_name,
            sentences_old=[],
            sentences_new=[],
            similarity_score=1.0
        )
        tree = Node(root_topic_name, data=node_data)

        # Add categories to tree
        add_tree_category_layer(tree, aggregate_class_json, root_topic_name, category_list)
        node_list = get_leaf_nodes(tree, depth)

        # Update similarity for nodes
        update_similarity_nodes(node_list, num_samples, year_old_start, year_old_end, year_new_start, year_new_end, similarity_type)
        continue

    # Loop through existing nodes and expand the tree
    for node in tqdm(node_list, desc="Big loop queue"):
        # Nodes that are already similar enough are not explored any further.
        if node.data.similarity_score > threshold_skip_node_exploration:
            logger.info(f"Skipped node {node.name}")
            continue

        # Generate question for current node: the chain of names from the root
        user_prompt = get_parent_node_names(node)
        generated_question = " -> ".join(user_prompt)
        logger.info(f"Generated question: {generated_question}")

        # Predictions for old and new data, each with one optional second round
        predictions_old, response_old, check_output = _predict_for_period(
            generated_question, year_old_start, year_old_end, retry_on_flat_output=True
        )
        predictions_new, response_new, check_output = _predict_for_period(
            generated_question, year_new_start, year_new_end, retry_on_flat_output=True
        )

        # Format and aggregate data for the current node
        json_data = format_data_to_json(node.name, predictions_old, predictions_new)
        aggregate_class_json = aggregate_classifications(json_data[node.name], root_topic_name, threshold_similarity_classes)

        # Add new categories to the tree
        add_tree_category_layer(node, aggregate_class_json, root_topic_name, category_list)

    # Update parameters after each depth level
    num_samples, subcategory_count, threshold_similarity_classes, threshold_skip_node_exploration = get_new_parameters(
        num_samples, subcategory_count, threshold_similarity_classes, threshold_skip_node_exploration, depth
    )
    node_list = get_leaf_nodes(tree, depth)

    # Update similarity for nodes at the current depth
    update_similarity_nodes(node_list, num_samples, year_old_start, year_old_end, year_new_start, year_new_end, similarity_type)
    print(f"Layer {depth} completed!")

# Save the final tree structure
print(RenderTree(tree))
save_tree_json(tree, name_for_json_output_file)

I0000 00:00:1727953010.123894   87839 check_gcp_environment.cc:61] BIOS data file does not exist or cannot be opened.
Batches: 100%|██████████| 1/1 [00:00<00:00,  6.89it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 141.35it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 125.00it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 139.91it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 203.28it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 125.48it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 123.94it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 77.09it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 77.04it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 111.60it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 67.85it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 103.20it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 72.31it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 99.19it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 62.90it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 97.06

KeyboardInterrupt: 

# Visualization Section after the generation of the tree

## Functionalities needed

In [None]:
! pip install plotly nbformat igraph dash

In [17]:
# Define the path to the tree which should be visualized.
# NOTE: this rebinds the global `tree`, replacing any tree built by the main loop above.
tree = load_tree_json("./sample_trees/tree_seaworld.json")
# Prefix of the HTML file written by create_plot.
output_plot_name = "seaworld"
# The plot size can be customized via the width and height arguments of fig.update_layout in create_plot.

In [18]:
from statistics import mean
# Function to find all paths from root to leaves
def get_paths_to_leaves(node):
    """
    Enumerates every path from `node` down to a leaf.

    Returns a list of paths, each path being a list of nodes that starts with
    `node` and ends with a node that has no children. Paths follow the
    children's stored order.
    """
    if not node.children:
        return [[node]]
    return [
        [node] + tail
        for child in node.children
        for tail in get_paths_to_leaves(child)
    ]

# Function to calculate average similarity score for every path
def calculate_avg_similarity_scores(root):
    """
    Computes the mean similarity score of every root-to-leaf path.

    Nodes whose score is None are left out of the mean; a path without any
    scored node is omitted from the result.

    Returns a list of `(path, average_score)` tuples in path enumeration order.
    """
    scored_paths = []
    for path in get_paths_to_leaves(root):
        known_scores = [
            step.data.similarity_score
            for step in path
            if step.data.similarity_score is not None
        ]
        if not known_scores:
            continue
        scored_paths.append((path, mean(known_scores)))
    return scored_paths

def get_lowest_similarity_pahts(path_similarity_scores):
    """
    Returns the (up to) three `(path, score)` entries with the lowest score.

    The input is not modified; entries with equal scores keep their input order.
    """
    ranked = list(path_similarity_scores)
    ranked.sort(key=lambda entry: entry[1])
    return ranked[:3]

# Function to print the paths and their average similarity scores
def print_paths_and_scores(path_similarity_scores):
    """
    Prints every path with its average similarity score, lowest score first.

    Each path is shown as its node names joined by " -> ".
    """
    for path, avg_score in sorted(path_similarity_scores, key=lambda entry: entry[1]):
        node_names = [node.name for node in path]
        path_str = " -> ".join(node_names)
        print(f"Path: {path_str}\nAverage Similarity Score: {avg_score:.4f}\n")

In [19]:
from anytree import Node
import igraph as ig
import plotly.graph_objects as go

def tree_to_igraph(root):
    """
    Converts the tree rooted at `root` into a directed igraph graph.

    Vertices are keyed by node name, so nodes sharing a name map to the same
    vertex. Each vertex stores the node's similarity score in its `label`
    attribute. Side effect: a similarity score of exactly 0.0 is overwritten
    with 1.0 on the tree node itself before it is copied.

    Returns a tuple `(graph, vertices)` where `vertices` maps node names to
    whatever `Graph.add_vertex` returned for them.
    """
    graph = ig.Graph(directed=True)
    vertices = {}

    def ensure_vertex(tree_node):
        # Register the node once; later encounters reuse the stored vertex.
        if tree_node.name in vertices:
            return
        if tree_node.data.similarity_score == 0.0:
            tree_node.data.similarity_score = 1.0
        vertices[tree_node.name] = graph.add_vertex(
            name=tree_node.name, label=tree_node.data.similarity_score
        )

    def add_edges(parent):
        ensure_vertex(parent)
        for child in parent.children:
            ensure_vertex(child)
            graph.add_edge(vertices[parent.name], vertices[child.name])
            add_edges(child)

    add_edges(root)
    return graph, vertices

# Build the igraph representation of the loaded tree once; `G` is what create_plot draws.
G, nodes = tree_to_igraph(tree)

def node_to_dict(node_name):
    """
    Looks up the node called `node_name` in the global `tree` and returns its
    content as a JSON string (indented by 4 spaces).

    Each entry of `matching_sentences` is expected to be indexable with at
    least three items: old sentence, new sentence and score.
    """
    node = find(tree, lambda candidate: candidate.name == node_name)
    payload = node.data

    sentence_matches = []
    for match in payload.matching_sentences:
        sentence_matches.append({
            "old_sentence": match[0],
            "new_sentence": match[1],
            "score": match[2]
        })

    node_dict = {
        "name": node.name,
        "data": {
            "category_name": payload.category_name,
            "sentences_new": payload.sentences_new,
            "sentences_old": payload.sentences_old,
            "similarity_score": payload.similarity_score,
            "matching_sentences": sentence_matches,
            "llm_comparator": payload.llm_comparator
        }
    }
    return json.dumps(node_dict, indent=4)

## 1. Create the interactive plot with per-node score coloring

In [22]:
import igraph as ig
import plotly.graph_objects as go
from dash import Dash, html, dcc, Output, Input
import pandas as pd
from dash import dcc
import plotly.io as pio  # Import Plotly's IO module
# Compute layout for the graph

def create_plot(G, name):
    """
    Draws the graph as a tree whose nodes are colored by similarity score.

    The figure is also written to "<output_plot_name>_score_colouring_tree.html"
    (using the module-level `output_plot_name`) and opened in the browser.

    Parameters:
    - G (igraph.Graph): Graph whose vertices carry a `name` and a numeric
      `label` (the similarity score).
    - name (str): Currently unused; kept so existing callers keep working.

    Returns:
    - plotly.graph_objects.Figure: The generated figure.
    """
    lay = G.layout('rt')  # Reingold-Tilford tree layout

    # One [x, y] coordinate pair per vertex, indexed by vertex id
    coords = [lay[k] for k in range(len(lay))]
    max_y = max(point[1] for point in coords)

    # Node coordinates; y is mirrored around max_y so the plot is flipped vertically
    Xn = [point[0] for point in coords]
    Yn = [2 * max_y - point[1] for point in coords]

    # Edge coordinates; the None entries break the line between separate edges
    Xe = []
    Ye = []
    for source, target in (edge.tuple for edge in G.es):
        Xe += [coords[source][0], coords[target][0], None]
        Ye += [2 * max_y - coords[source][1], 2 * max_y - coords[target][1], None]

    # Extract similarity scores from labels and prepare hover texts
    similarity_scores = [v['label'] for v in G.vs]  # Similarity scores are stored in the label attribute
    hover_texts = [f"Node: {G.vs[i]['name']}<br>Score: {similarity_scores[i]:.2f}" for i in range(len(G.vs))]

    # Color scale used for the node markers
    color_scale = 'reds'

    # Create Plotly traces
    fig = go.Figure()

    # Add edges
    fig.add_trace(go.Scatter(
        x=Xe,
        y=Ye,
        mode='lines',
        line=dict(color='rgb(0,0,0)', width=1),
        hoverinfo='none',
        opacity=1
    ))

    # Add nodes with color gradient based on similarity score
    fig.add_trace(go.Scatter(
        x=Xn,
        y=Yn,
        mode='markers',
        name='Nodes',
        marker=dict(
            symbol='circle-dot',
            size=10,  # Adjust dot size as needed
            color=similarity_scores,  # Color nodes based on similarity score
            colorscale=color_scale,  # Color scale for similarity score
            colorbar=dict(
                title='Similarity Score',
                tickvals=[min(similarity_scores), max(similarity_scores)],
                ticktext=[f"{min(similarity_scores):.2f}", f"{max(similarity_scores):.2f}"]
            ),
            line=dict(color='rgb(50,50,50)', width=0.2)
        ),
        text=hover_texts,
        hoverinfo='text',
        opacity=1,
        customdata=[v['name'] for v in G.vs],  # Node names, read by the click callback
    ))

    # Customize layout
    fig.update_layout(
        showlegend=False,
        xaxis=dict(showgrid=False, zeroline=False),
        yaxis=dict(showgrid=False, zeroline=False),
        width=4000,  # Width of the plot in pixels
        height=700   # Height of the plot in pixels
    )
    pio.write_html(fig, file=output_plot_name+'_score_colouring_tree.html', auto_open=True)
    return fig

# Create a Dash application
app = Dash(__name__)

# Layout of the Dash application: the tree plot plus a text area for node details.
# Building the layout calls create_plot, which also writes the HTML file and opens it.
app.layout = html.Div([
    dcc.Graph(id='interactive-plot', figure=create_plot(G, 'example')),
    html.Div(id='node-info', style={'whiteSpace': 'pre-line', 'margin-top': '20px'}),
])

# Callback to update the text area when a node is clicked
@app.callback(
    Output('node-info', 'children'),
    [Input('interactive-plot', 'clickData')]
)
def display_node_info(clickData):
    """Render the details of the clicked tree node below the plot.

    Args:
        clickData: Click payload of the 'interactive-plot' graph, or None
            while nothing has been clicked yet.

    Returns:
        A hint string when there is no node to show, otherwise a
        ``dcc.Markdown`` component holding the node's JSON in a code fence.
    """
    hint = "Click on a node to see more information."
    if not clickData or not clickData.get('points'):
        return hint

    # Only the node trace sets customdata. The edge trace is drawn with
    # hoverinfo='none', which hides the tooltip but still lets click events
    # through, so a click that resolves to an edge point carries no node name
    # and must not raise a KeyError here.
    node_name = clickData['points'][0].get('customdata')
    if node_name is None:
        return hint

    data = node_to_dict(node_name)

    # Format JSON data using Markdown for better display
    formatted_data = f"```json\n{data}\n```"

    # Return a Markdown component with the formatted JSON data
    return dcc.Markdown(formatted_data)

# Start the Dash server in debug mode when this code runs as the main module
if __name__ == '__main__':
    app.run_server(debug=True)


## 2. Creating a plot with aggregated score coloring

In [23]:
from anytree import Node
import igraph as ig
import plotly.graph_objects as go
import json
from dash import Dash, html, dcc, Output, Input
import numpy as np

# Bottom-up aggregation of the similarity score for a single node
def aggregate_scores(node):
    """Return the aggregated similarity score of ``node``.

    A leaf yields its own ``similarity_score``. Any other node yields its own
    score plus the mean of its children's aggregated scores, which are
    computed recursively.
    """
    own_score = node.data.similarity_score
    if node.is_leaf:
        return own_score

    kids = node.children
    total_of_children = sum(aggregate_scores(kid) for kid in kids)
    # Guard against an empty child list so the division can never fail
    mean_of_children = total_of_children / len(kids) if kids else 0
    return own_score + mean_of_children

def get_aggregated_scores(root):
    """Map every node name in the tree to its aggregated score.

    Visits all descendants of ``root`` followed by ``root`` itself. Nodes that
    share a name overwrite one another, so the last one visited wins.
    """
    every_node = root.descendants + (root,)
    return {member.name: aggregate_scores(member) for member in every_node}


# Calculate the aggregated score of every node in the global `tree`,
# keyed by node name
aggregated_scores = get_aggregated_scores(tree)

# Build an igraph graph that mirrors the anytree hierarchy
def tree_to_igraph(root, scores):
    """Convert the anytree rooted at ``root`` into a directed igraph graph.

    Every vertex stores the node name, the aggregated score taken from
    ``scores`` under ``label`` (used for colouring) and the node's own
    similarity score under ``original_label`` (used for hover text). Vertices
    are keyed by node name, so nodes sharing a name share a single vertex.

    Returns:
        A ``(graph, vertices)`` tuple, where ``vertices`` maps each node name
        to the value returned by ``Graph.add_vertex`` for it.
    """
    graph = ig.Graph(directed=True)
    vertex_by_name = {}

    def ensure_vertex(tree_node):
        # Add the vertex only the first time this name is encountered
        if tree_node.name not in vertex_by_name:
            vertex_by_name[tree_node.name] = graph.add_vertex(
                name=tree_node.name,
                label=scores[tree_node.name],
                original_label=tree_node.data.similarity_score,
            )

    def walk(parent):
        # Pre-order traversal: parent, then each child followed by its subtree
        ensure_vertex(parent)
        for kid in parent.children:
            ensure_vertex(kid)
            graph.add_edge(vertex_by_name[parent.name], vertex_by_name[kid.name])
            walk(kid)

    walk(root)
    return graph, vertex_by_name

# Convert the global `tree` to an igraph graph; `G` is the graph passed to
# create_plot below and `nodes` maps node names to their igraph vertices
G, nodes = tree_to_igraph(tree, aggregated_scores)

# Serialise one tree node, looked up by name, into JSON text
def node_to_dict(node_name):
    """Return a JSON string describing the node called ``node_name``.

    The node is searched for in the global ``tree`` with ``find``. Despite the
    function name, the result is the text produced by ``json.dumps`` with a
    4-space indent, not a dict.
    """
    target = find(tree, lambda candidate: candidate.name == node_name)
    details = target.data

    # Each match is read positionally: old sentence, new sentence, score
    matches = []
    for match in details.matching_sentences:
        matches.append({
            "old_sentence": match[0],
            "new_sentence": match[1],
            "score": match[2],
        })

    payload = {
        "name": target.name,
        "data": {
            "category_name": details.category_name,
            "sentences_new": details.sentences_new,
            "sentences_old": details.sentences_old,
            "similarity_score": details.similarity_score,
            "matching_sentences": matches,
            "llm_comparator": details.llm_comparator,
        },
    }
    return json.dumps(payload, indent=4)

# Build the Plotly figure for the tree coloured by aggregated score
def create_plot(G, name):
    """Draw the tree as a Plotly figure with aggregated-score colouring.

    Args:
        G: igraph graph whose vertices carry the attributes ``name``,
            ``label`` (aggregated score, used for the colour) and
            ``original_label`` (the node's own similarity score, shown on
            hover).
        name: Currently unused; the output file name is built from the global
            ``output_plot_name`` instead.

    Returns:
        The ``go.Figure``. As a side effect the figure is written to
        ``output_plot_name + '_aggregation_colouring_tree_text.html'`` with
        ``auto_open=True``.
    """
    lay = G.layout('rt')  # Reingold-Tilford tree layout

    # Layout coordinates, addressed by vertex index
    position = [lay[k] for k in range(len(lay))]
    M = max(point[1] for point in position)

    # Node coordinates; y is reflected (y -> 2 * M - y) so that the smallest
    # layout y ends up at the top of the plot
    Xn = [point[0] for point in position]
    Yn = [2 * M - point[1] for point in position]

    # Edge coordinates; the trailing None breaks the line between edges
    Xe = []
    Ye = []
    for edge in G.es:
        source, target = edge.tuple
        Xe += [position[source][0], position[target][0], None]
        Ye += [2 * M - position[source][1], 2 * M - position[target][1], None]

    # Extract aggregated scores for coloring
    aggregated_scores = np.array([v['label'] for v in G.vs])
    # Apply log scaling for coloring.
    # NOTE(review): log1p is undefined for values <= -1; this assumes the
    # aggregated scores stay above -1 - confirm.
    log_aggregated_scores = np.log1p(aggregated_scores)

    # Extract original similarity scores for hover text
    original_scores = [v['original_label'] for v in G.vs]
    hover_texts = [f"Node: {G.vs[i]['name']}<br>Original Score: {original_scores[i]:.2f}" for i in range(len(G.vs))]

    # Define a color scale
    color_scale = 'reds'

    # Colorbar ticks sit at the log-scaled extremes (the colour values), but
    # are labelled with the real aggregated scores so that the labels match
    # the 'Aggregated Score' title.
    tick_positions = [min(log_aggregated_scores), max(log_aggregated_scores)]
    tick_labels = [f"{min(aggregated_scores):.2f}", f"{max(aggregated_scores):.2f}"]

    # Create Plotly traces
    fig = go.Figure()

    # Add edges
    fig.add_trace(go.Scatter(
        x=Xe,
        y=Ye,
        mode='lines',
        line=dict(color='rgb(0,0,0)', width=1),
        hoverinfo='none',
        opacity=1
    ))

    # Add nodes with color gradient based on log-scaled aggregated score
    fig.add_trace(go.Scatter(
        x=Xn,
        y=Yn,
        mode='markers',
        name='Nodes',
        marker=dict(
            symbol='circle-dot',
            size=10,  # Adjust dot size as needed
            color=log_aggregated_scores,  # Color nodes based on log-scaled aggregated score
            colorscale=color_scale,  # Color scale for log-scaled aggregated score
            colorbar=dict(
                title='Aggregated Score',
                tickvals=tick_positions,
                ticktext=tick_labels
            ),
            line=dict(color='rgb(50,50,50)', width=0.2)
        ),
        text=hover_texts,
        hoverinfo='text',
        opacity=1,
        customdata=[v['name'] for v in G.vs],  # Add custom data for callback
    ))

    # Customize layout
    fig.update_layout(
        showlegend=False,
        xaxis=dict(showgrid=False, zeroline=False),
        yaxis=dict(showgrid=False, zeroline=False),
        width=5000,  # Width of the plot in pixels
        height=500   # Height of the plot in pixels
    )
    pio.write_html(fig, file=output_plot_name + '_aggregation_colouring_tree_text.html', auto_open=True)
    return fig

# Dash application hosting the aggregated-score tree plot
app = Dash(__name__)

# Page layout: the tree figure on top and, below it, a text panel that the
# click callback fills with the details of the selected node.
app.layout = html.Div(
    children=[
        dcc.Graph(
            figure=create_plot(G, 'example'),
            id='interactive-plot',
        ),
        html.Div(
            style={'whiteSpace': 'pre-line', 'margin-top': '20px'},
            id='node-info',
        ),
    ],
)

# Callback to update the text area when a node is clicked
@app.callback(
    Output('node-info', 'children'),
    [Input('interactive-plot', 'clickData')]
)
def display_node_info(clickData):
    """Render the details of the clicked tree node below the plot.

    Args:
        clickData: Click payload of the 'interactive-plot' graph, or None
            while nothing has been clicked yet.

    Returns:
        A hint string when there is no node to show, otherwise a
        ``dcc.Markdown`` component holding the node's JSON in a code fence.
    """
    hint = "Click on a node to see more information."
    if not clickData or not clickData.get('points'):
        return hint

    # Only the node trace sets customdata. The edge trace is drawn with
    # hoverinfo='none', which hides the tooltip but still lets click events
    # through, so a click that resolves to an edge point carries no node name
    # and must not raise a KeyError here.
    node_name = clickData['points'][0].get('customdata')
    if node_name is None:
        return hint

    data = node_to_dict(node_name)
    # Format JSON data using Markdown for better display
    formatted_data = f"```json\n{data}\n```"

    return dcc.Markdown(formatted_data)

# Start the Dash server in debug mode when this code runs as the main module
if __name__ == '__main__':
    app.run_server(debug=True)


## 3. Creating a treemap
Displaying the reasoning is not implemented here yet. Use the plots above to see it.

In [26]:
import plotly.express as px
import pandas as pd
import plotly.graph_objects as go

def insert_line_breaks(text, max_length=100):
    """Split ``text`` into chunks of ``max_length`` characters joined by ``<br>``.

    Chunks are cut at fixed character offsets, regardless of word boundaries.
    An empty ``text`` yields an empty string.
    """
    pieces = []
    for start in range(0, len(text), max_length):
        pieces.append(text[start:start + max_length])
    return "<br>".join(pieces)

# Function to extract data from anytree nodes
def get_tree_paths(node, node_name_list=None, parent_list=None, similarity_score_list=None, new=None, old=None):
    """Flatten the subtree rooted at ``node`` into parallel lists (pre-order).

    Args:
        node: Tree node exposing ``name``, ``parent``, ``children`` and a
            ``data`` object with ``similarity_score``, ``sentences_new`` and
            ``sentences_old``.
        node_name_list, parent_list, similarity_score_list, new, old:
            Accumulator lists that are appended to in place. Each one that is
            left as None starts out as a fresh empty list.

    Returns:
        The tuple ``(node_name_list, parent_list, similarity_score_list, new,
        old)``. Entry ``i`` of every list describes the same node; a node
        without a parent is recorded under the parent name "Topic".
    """
    # Create fresh accumulators per top-level call. Mutable default arguments
    # are shared between calls, so a second call would otherwise return the
    # rows of the first call as well.
    if node_name_list is None:
        node_name_list = []
    if parent_list is None:
        parent_list = []
    if similarity_score_list is None:
        similarity_score_list = []
    if new is None:
        new = []
    if old is None:
        old = []

    node_name_list.append(node.name)
    if node.parent is None:
        parent_list.append("Topic")
    else:
        parent_list.append(node.parent.name)
    similarity_score_list.append(node.data.similarity_score)

    # Apply line break insertion to each sentence in the new and old lists
    new.append("<br>".join([insert_line_breaks(sentence) for sentence in node.data.sentences_new]))
    old.append("<br>".join([insert_line_breaks(sentence) for sentence in node.data.sentences_old]))

    # Recurse with the same list objects so that all rows end up in one place
    for child in node.children:
        get_tree_paths(child, node_name_list, parent_list, similarity_score_list, new, old)

    return node_name_list, parent_list, similarity_score_list, new, old

# Flatten the tree into the parallel lists used to build the treemap
node_name_list, parent_list, similarity_score_list, new, old = get_tree_paths(tree)

# Treemap whose tiles are sized and coloured by each node's similarity score;
# the new/old sentences travel along as hover data
fig = px.treemap(
    names=node_name_list,
    parents=parent_list,
    values=similarity_score_list,
    color=similarity_score_list,
    color_continuous_scale='Reds',
    hover_data={'New Sentences': new, 'Old Sentences': old},
)

# Custom hover box: label, score, then the new and old sentences
# (customdata[0] and customdata[1] follow the order of hover_data above)
fig.update_traces(
    hovertemplate="".join([
        "<b>%{label}</b><br>",
        "<b>Score:</b> %{value}<br>",
        "<b>New:</b> %{customdata[0]}<br>\n",
        "<b>Old:</b> %{customdata[1]}<extra></extra>",
    ])
)

# Figure size and margins
fig.update_layout(
    width=1800,
    height=1000,
    margin={'t': 50, 'l': 25, 'r': 25, 'b': 25},
)

# Save a standalone HTML copy, then display the figure
fig.write_html(output_plot_name + "_score_colouring_treemap_text.html")
fig.show()
