# Insights
The insights notebook is our main tool for visualizing syftr results.
1. At the end of the notebook, edit the arguments of `filter_studies_and_generate_report`: the study filters (`include_regex`, `exclude_regex`) and the focus-study pattern (`focus_regex`)
    * (either edit the regex filters or put in specific study names)
2. Run all cells in the notebook (click "Run All"); everything should just work
3. All charts are displayed in the notebook
4. A PDF report that you can share is also written to the **results** folder (`cfg.paths.results_dir`)

In [None]:
%pip install -qq "dataframe_image>=0.2.7" diskcache adjustText seaborn statsmodels altair[all]
%reload_ext autoreload
%autoreload 2

In [None]:
import os

# The notebook expects to run from the `syftr` checkout. If the current
# directory is not it (presumably the notebook was started from a
# subdirectory of the repo), move up one level so relative paths resolve.
_here = os.getcwd()
if not _here.endswith("syftr"):
    os.chdir(os.path.dirname(_here))
    print(f"Changed working directory to: {os.getcwd()}")
del _here

In [None]:
import re
import warnings

import numpy as np
import pandas as pd
from IPython.display import Markdown, display
from matplotlib.backends.backend_pdf import PdfPages

from syftr.configuration import cfg
from syftr.optuna_helper import get_study_names
from syftr.plotting.insights import ( # noqa
    CACHE, accuracy_plot,
    all_parameters_all_studies_plot,
    append_figure, append_table, compute_plot,
    compute_trial_rate_plot, cost_plot,
    create_benchmark_plot_and_table,
    create_exceptions_table, descriptive_name,
    focus_over_time_plot, get_name,
    latency_plot, load_studies,
    param_pareto_plot, param_plot,
    param_plot_all_studies, pareto_area_plot,
    pareto_comparison_plot,
    pareto_plot_and_table, plot_all_paretos,
    plot_metric_variability,
    plot_retriever_pareto,
    slowest_components_plot,
    study_similarity_all_pairs_plot,
    study_similarity_all_pairs_scatters_plot,
    study_similarity_plot,
    trial_duration_hist
)

# Make pandas strict: chained assignment raises instead of warning, and show
# up to 200 rows when displaying frames.
pd.options.mode.chained_assignment = "raise"
pd.options.display.max_rows = 200
# Treat FutureWarnings as errors so deprecated usage fails loudly.
warnings.simplefilter(action="error", category=FutureWarning)
# Ignore numpy floating-point overflow instead of warning about it.
np.seterr(over="ignore")


# Parameter columns shown first in the report, in this order; any other
# `params_*` column of the focus study is appended after them. The leading
# tuple is a combined entry handed to the plotting helpers as a single
# "parameter" (presumably plotted jointly -- see `param_plot`/`descriptive_name`).
_PRIORITY_PARAM_COLS = (
    (
        "params_rag_mode",
        "params_llm_name",
        "params_template_name",
    ),
    "params_rag_mode",
    "params_llm_name",
    "params_template_name",
    "params_rag_method",
    "params_rag_embedding_model",
    "params_splitter_method",
    "params_splitter_chunk_exp",
    "params_splitter_chunk_overlap_frac",
    "params_rag_top_k",
    "params_reranker_enabled",
    "params_hyde_enabled",
)


def _is_varying_param(cols, df):
    """Return True if every column in *cols* exists in *df* and is not constant.

    *cols* is a single column name or a tuple of column names. Missing values
    count as a distinct value (``nunique(dropna=False)``).
    """
    if isinstance(cols, str):
        cols = [cols]
    return all(c in df.columns and df[c].nunique(dropna=False) > 1 for c in cols)


def _select_param_cols(df, df_study):
    """Return the parameter entries worth plotting, priority entries first.

    Candidates are the priority entries plus every other ``params_*`` column
    of the focus study; only entries that vary across *df* are kept.
    """
    candidates = list(_PRIORITY_PARAM_COLS)
    candidates += [
        c
        for c in df_study.columns
        if c.startswith("params_") and c not in candidates
    ]
    return [c for c in candidates if _is_varying_param(c, df)]


def _default_pdf_filename(study_name):
    """Create the configured results directory if needed; return the report path."""
    cfg.paths.results_dir.mkdir(parents=True, exist_ok=True)
    return str(cfg.paths.results_dir.resolve() / f"insights_{study_name}.pdf")


async def create_pdf_report(study_name, all_study_names=None, pdf_filename=None, titles=None, insights_prefix=None):
    """Render the insights charts for one focus study and write them to a PDF.

    Every chart/table is displayed inline in the notebook and appended to the
    PDF in the same order, grouped into sections: benchmark summary, Pareto
    frontier, optimization progress, parameter analysis, study similarity and
    miscellaneous analysis.

    Args:
        study_name: Name of the focus study; most charts are computed for it.
        all_study_names: Study names handed to ``load_studies`` and to
            ``plot_all_paretos``. If the focus study is not among the loaded
            data, the report fails with ``ValueError``.
        pdf_filename: Output path. Defaults to
            ``<cfg.paths.results_dir>/insights_<study_name>.pdf``; the results
            directory is created if needed.
        titles: Optional mapping of study name to a display title for charts.
        insights_prefix: Name prefix for exported figures, forwarded to
            ``append_figure``.

    Returns:
        The path of the written PDF (as given, or the default as a ``str``).

    Raises:
        ValueError: If the focus study has no rows in the loaded data.
    """
    # load all studies; rows are filtered to the focus study by `study_name`
    df, study_stats_table, exceptions_table = load_studies(all_study_names)
    df_study = df[df["study_name"] == study_name]
    param_cols = _select_param_cols(df, df_study)

    if pdf_filename is None:
        pdf_filename = _default_pdf_filename(study_name)

    with PdfPages(pdf_filename) as pdf:
        display(Markdown("# Benchmark Summary"))

        # study stats table
        await append_table(pdf, study_stats_table, title="Study Stats")

        # exceptions table
        table = create_exceptions_table(df, exceptions_table)
        await append_table(pdf, table, title="Top Exceptions")

        # Validate the focus study only after the study-wide tables above have
        # been written (presumably so they are still visible for debugging).
        if len(df_study) == 0:
            raise ValueError(f"The focus study '{study_name}' contains 0 trials.")

        # benchmark accuracy table
        fig, table = create_benchmark_plot_and_table(df)
        if fig is not None:
            append_figure(pdf, fig, insights_prefix)
        if table is not None:
            await append_table(pdf, table, title="Benchmark Performance")

        # metric variability chart
        fig = plot_metric_variability(df, study_name)
        if fig is not None:
            append_figure(pdf, fig, insights_prefix)

        display(Markdown("# Pareto Frontier"))

        # all pareto fronts (retriever-only studies use a dedicated chart)
        if "retriever-only" in study_name:
            fig, title = plot_retriever_pareto(df, study_name, titles=titles)
        else:
            fig, title = plot_all_paretos(df, all_study_names, titles=titles)
        append_figure(pdf, fig, insights_prefix, title=title)

        # pareto front with descriptions
        fig, table, fig_title = pareto_plot_and_table(df, study_name, titles=titles)
        append_figure(pdf, fig, insights_prefix, title=fig_title)
        await append_table(pdf, table, title=f"Pareto Frontier ({get_name(study_name, titles)})")

        display(Markdown("# Optimization Progress"))

        # optimization focus over time
        fig = focus_over_time_plot(df, study_name, titles=titles)
        append_figure(pdf, fig, insights_prefix)

        # compute trial rate plot
        fig = compute_trial_rate_plot(df, study_name, titles=titles)
        append_figure(pdf, fig, insights_prefix)

        # compute plot
        fig = compute_plot(df, study_name, titles=titles)
        append_figure(pdf, fig, insights_prefix)

        # cost plot
        fig = cost_plot(df, study_name, titles=titles)
        append_figure(pdf, fig, insights_prefix)

        # histogram of trial durations
        fig = trial_duration_hist(df, study_name, titles=titles)
        append_figure(pdf, fig, insights_prefix)

        # historical pareto plots
        fig = pareto_comparison_plot(df, study_name, titles=titles)
        if fig is not None:
            append_figure(pdf, fig, insights_prefix)

        # pareto area plot
        fig = pareto_area_plot(df, study_name, titles=titles)
        append_figure(pdf, fig, insights_prefix)

        # accuracy plot
        fig = accuracy_plot(df, study_name, titles=titles)
        append_figure(pdf, fig, insights_prefix)

        # latency plot
        fig = latency_plot(df, study_name)
        if fig is not None:
            append_figure(pdf, fig, insights_prefix)

        display(Markdown("# Parameter Analysis"))

        # plot summary of all params
        fig, title = all_parameters_all_studies_plot(df, param_cols, group_col='study_name', titles=titles)
        append_figure(pdf, fig, insights_prefix, title=title)

        # plots for each parameter
        for param_col in param_cols:
            display(Markdown(f"### {descriptive_name(param_col)} ({param_col}):"))

            # Disabled: param stats across all studies.
            # if df["study_name"].nunique() > 1:
            #     fig, title = param_plot_all_studies(df, study_name, param_col)
            #     if fig is not None:
            #         append_figure(pdf, fig, insights_prefix, title=title)

            # param stats of the focus study
            fig, title = param_plot(df, study_name, param_col, titles=titles)
            append_figure(pdf, fig, insights_prefix, title=title)

            # Per-value Pareto frontier, only for single object-dtype columns
            # with at most 10 distinct values (presumably to keep it readable).
            if (
                isinstance(param_col, str)
                and df[param_col].dtype == "O"
                and df[param_col].nunique(dropna=False) <= 10
            ):
                fig, title = param_pareto_plot(df, study_name, param_col, titles=titles)
                append_figure(pdf, fig, insights_prefix, title=title)

        display(Markdown("# Study Similarity Analysis"))

        if df["study_name"].nunique() > 1:
            # single study similar to all others
            fig = study_similarity_plot(df, study_name, titles=titles)
            append_figure(pdf, fig, insights_prefix)

            # correlation matrix of all studies
            fig = study_similarity_all_pairs_plot(df, study_name, titles=titles)
            append_figure(pdf, fig, insights_prefix)

            # Disabled: matrix display of scatter plots.
            # fig, title = study_similarity_all_pairs_scatters_plot(df, study_name, titles=titles)
            # append_figure(pdf, fig, insights_prefix, title=title)

        display(Markdown("# Miscellaneous Analysis"))

        # plot slowest components and params
        fig = slowest_components_plot(df, study_name)
        append_figure(pdf, fig, insights_prefix)

    return pdf_filename


async def filter_studies_and_generate_report(include_regex, exclude_regex, focus_regex, reset_cache, titles, insights_prefix=None):
    """Select studies by regex, pick a focus study, and build its insights PDF.

    Args:
        include_regex: Pattern(s) forwarded to ``get_study_names`` to select studies.
        exclude_regex: Pattern(s) forwarded to ``get_study_names`` to drop studies.
        focus_regex: Ordered pattern(s) tried with ``re.match``; the first
            pattern that matches any selected study wins, and within that
            pattern the first matching study name is chosen. A single string
            is treated as a one-element list.
        reset_cache: If true, clear ``CACHE`` before loading.
        titles: Mapping of study name to display title, forwarded to the report.
        insights_prefix: Name prefix for exported figures.

    Returns:
        Absolute path of the written PDF report.

    Raises:
        ValueError: If no selected study matches any focus pattern.
    """
    # cache
    if reset_cache:
        CACHE.clear()
        print("Cache cleared.")
    print(f"Cache volume: {CACHE.volume():,} bytes\n")

    # filter the study names
    study_names = get_study_names(
        include_regex=include_regex,
        exclude_regex=exclude_regex,
    )

    # A bare string would otherwise be iterated character by character, each
    # character being tried as its own pattern.
    if isinstance(focus_regex, str):
        focus_regex = [focus_regex]

    # Pattern order is priority order: earlier patterns win over later ones.
    focus_study = next(
        (
            name
            for pattern in focus_regex
            for name in study_names
            if re.match(pattern, name)
        ),
        None,
    )
    if focus_study is None:
        raise ValueError("No matching study found for given focus_regex")

    # generate charts and report; `study_names` includes the focus study itself
    other_count = len(study_names) - 1
    print(f'Analyzing the focus study "{focus_study}" compared to {other_count} other studies:')
    print("    " + "\n    ".join(study_names) + "\n")
    pdf_filename = await create_pdf_report(focus_study, all_study_names=study_names, titles=titles, insights_prefix=insights_prefix)
    full_path = os.path.abspath(pdf_filename)
    print(f"Report saved to: {full_path}")
    print("Done!", flush=True)
    return full_path


################################################################################
# EDIT BELOW: choose the studies to analyze (by name or regex) and generate the report
await filter_studies_and_generate_report(
    # regexes selecting which studies to load (passed to get_study_names)
    include_regex=[
        # ".*seeding1--training .*",
        # ".*seeding1--testing.*",
        "stefan--example1--small-study--drdocs_hf",
    ],
    # regexes for studies to drop from the selection
    exclude_regex=[
    ],
    # tried in order with re.match; the first matching study becomes the focus
    focus_regex=[
        # ".*financebench.*",
        ".*",  # use first matching study
    ],
    # True clears the plotting CACHE before the studies are loaded
    reset_cache=True,
    # optional display titles for study names, used in chart labels
    titles={
        "cerebras4--rag-and-agents-cerebras-only--financebench_hf": "FinanceBench (Cerebras RAG and Agents)",
        "cerebras4--rag-and-agents-local-only--financebench_hf": "FinanceBench (Local RAG and Agents)",
        "cerebras4--rag-and-agents-cerebras-only--phantomwikiv050_hf--depth_20_size_10000_seed_3": "PhantomWiki (Cerebras RAG and Agents)",
        "cerebras4--rag-and-agents-local-only--phantomwikiv050_hf--depth_20_size_10000_seed_3": "PhantomWiki (Local RAG and Agents)",
        "cerebras5--rag-and-agents-cerebras-only--financebench_hf": "FinanceBench (Cerebras RAG and Agents)",
        "cerebras5--rag-and-agents-local-only--financebench_hf": "FinanceBench (Local RAG and Agents)",
        "cerebras5--rag-and-agents-cerebras-only--phantomwikiv050_hf--depth_20_size_10000_seed_3": "PhantomWiki (Cerebras RAG and Agents)",
        "cerebras5--rag-and-agents-local-only--phantomwikiv050_hf--depth_20_size_10000_seed_3": "PhantomWiki (Local RAG and Agents)",
    },
    insights_prefix="",  # the name prefix for exported figures
)
# EDIT ABOVE # ^^
################################################################################

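# Debugging
Useful snippets for inspecting raw Optuna studies and prepared study data. They assume `optuna` and `tqdm` have been imported (e.g. `import optuna, tqdm`), which the notebook does not do by default.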
---
```python
# Get and filter study names
study_names = optuna.get_all_study_names(storage=cfg.postgres.get_optuna_storage())
study_names = [s for s in study_names if re.match(r'^bench10.*', s) is not None] # include filter
study_names = [s for s in study_names if re.match(r'.*synthetic_.*', s) is None] # exclude filter

# Load raw studies from optuna
def load_raw_study(name):
    df = optuna.load_study(study_name=name, storage=cfg.postgres.get_optuna_storage()).trials_dataframe()
    df['study_name'] = name
    return df
df = pd.concat([load_raw_study(name) for name in tqdm.tqdm(study_names)])
display(df.T)

# Load prepared studies
study_names = ['bench14--batch-1--financebench']
df, study_stats_table, exceptions_table = load_studies(study_names, only_successful_trials=False)
display(df.T)
```
---

In [None]:
# study_names = [
#     'bench14--batch-1--financebench',
# ]
# df, study_stats_table, exceptions_table = load_studies(study_names, only_successful_trials=False)
# # print(df.loc[df[df.params_rag_mode == "rag"].values_0.idxmax()].user_attrs_flow)
# rag_only_df = df[df.params_rag_mode == "rag"]
# for index, row in rag_only_df.iterrows():
#     print(row['user_attrs_flow'])

In [None]:
# df, study_stats_table, exceptions_table  = load_studies([
#     "bench13--batch-1--crag-music",
#     "bench13--batch-1--financebench",
#     "bench13--batch-1--hotpot-train-hard",
#     "bench13--batch-1--infinitebench",
# ])

# param_cols = [col for col in df.columns if col.startswith('params_') or 'prun' in col]
# corrs_table = what_correlates_with(df[param_cols], df['values_0'] == 0)

# pd.set_option("display.max_rows", 400)
# display(corrs_table)