# Study Query LLM - PCA KLLMeans Analysis & Visualization

This notebook loads previously saved sweep results from pickle files and provides:
1. Visualization of stability metrics across K values and summarizers
2. Data comparison and validation
3. Representative text comparison across summarizers
4. Custom analysis of results

**Prerequisites:** Run `pca_kllmeans_sweep.ipynb` first to generate pickle files.

## Install dependencies (REQUIRED - Run this first!)

**IMPORTANT:** You must install the package dependencies before running the notebook.

In [None]:
# Install the package in editable mode.
# `%pip` is the IPython magic that runs pip for the current kernel; `-e ..` targets the
# parent of the notebook's working directory (presumably the project root — confirm if
# the notebook is moved).
%pip install -e ..

# After installation, restart the kernel (Kernel -> Restart Kernel) and run cells from the top

## Load Results from Pickle File

In [None]:
import pickle
import numpy as np
from pathlib import Path

# Locate a saved sweep-results pickle, load it, and normalise it into the three
# names the remaining cells read: `results`, `ground_truth_labels`, `dataset_name`.

# Option 1: Specify the exact filename
# saved_file = "pca_kllmeans_sweep_results_20260204_231913.pkl"

# Option 2: Load the most recent pickle file (default)
# Reverse lexicographic order puts the newest file first, assuming the suffix is a
# sortable timestamp as in the example filename above.
pickle_files = sorted(Path(".").glob("pca_kllmeans_sweep_results_*.pkl"), reverse=True)
if pickle_files:
    saved_file = pickle_files[0]
    print(f"[INFO] Loading results from: {saved_file}")

    # NOTE: unpickling can execute arbitrary code. Only load files produced by your
    # own run of the sweep notebook, never files from an untrusted source.
    with open(saved_file, "rb") as f:
        loaded_data = pickle.load(f)

    # Handle both old format (dict of summarizers) and new format (with metadata)
    if "summarizers" in loaded_data:
        # New format with metadata
        results = loaded_data["summarizers"]
        ground_truth_labels = loaded_data.get("ground_truth_labels")
        dataset_name = loaded_data.get("dataset_name", "unknown")
        print(f"[OK] Loaded new format: {len(results)} summarizer(s), dataset: {dataset_name}")
        if ground_truth_labels is not None:
            ground_truth_labels = np.array(ground_truth_labels)  # Convert back to numpy array
            print(f"   Ground truth: {len(ground_truth_labels)} samples, {len(set(ground_truth_labels))} clusters")
        else:
            print("   Ground truth: None")
    else:
        # Old format (backward compatible): the pickle itself is the summarizer dict
        results = loaded_data
        ground_truth_labels = None
        dataset_name = "unknown"
        print(f"[OK] Loaded old format: {len(results)} summarizer(s) (backward compatible)")

    # Show structure, using the first summarizer as a sample
    if results:
        first_key = list(results.keys())[0]
        first_data = results[first_key]
        if 'by_k' in first_data:
            k_values = sorted([int(k) for k in first_data['by_k'].keys()])
            print(f"   K values: {k_values}")
            # An empty `by_k` has no first K to show; skip the hint instead of
            # failing on k_values[0].
            if k_values:
                print(f"   Example access: results['{first_key}']['by_k']['{k_values[0]}']['stability']")
else:
    # Define the same names with empty defaults so later cells can still run.
    print("[ERROR] No pickle files found. Run pca_kllmeans_sweep.ipynb first to generate results.")
    results = {}
    ground_truth_labels = None
    dataset_name = "unknown"

In [None]:
from collections import defaultdict
from typing import Any, Dict, Set, Tuple


def analyze_dict_structure(
    data: Any,
    indent: int = 0,
    seen_structures: Dict[Tuple[Tuple[type, ...], Tuple[type, ...]], int] = None,
    path: str = "root"
) -> None:
    """
    Print a level-by-level summary of the types found in a nested dictionary.

    Each dictionary visited produces three lines: a running count of how many
    dictionaries with the same (key types, value types) signature have been
    seen so far in this traversal, the key type names, and the value type
    names. Values that are dictionaries are then visited recursively in
    insertion order. Non-dict input is ignored silently.

    Args:
        data: Object to inspect; anything that is not a dict is skipped.
        indent: Nesting depth used for the printed "Level" and the left
            padding (two spaces per unit). Recursive calls add 2.
        seen_structures: Counter of signatures shared across the recursion.
            A fresh one is created when omitted.
        path: Dotted path of the current dict; only extended and forwarded
            to recursive calls, never printed.
    """
    if not isinstance(data, dict):
        return
    if seen_structures is None:
        seen_structures = defaultdict(int)

    def type_name(kind: type) -> str:
        return kind.__name__

    # Sort by name so the signature does not depend on set iteration order.
    key_kinds = tuple(sorted({type(k) for k in data}, key=type_name))
    value_kinds = tuple(sorted({type(v) for v in data.values()}, key=type_name))
    children = [(k, v) for k, v in data.items() if isinstance(v, dict)]

    signature = (key_kinds, value_kinds)
    seen_structures[signature] += 1
    occurrences = seen_structures[signature]

    pad = "  " * indent
    print(f"{pad}Level {indent}: {occurrences} dict(s) with this structure")
    print(f"{pad}  Keys: {' | '.join(map(type_name, key_kinds))}")

    # Plain dict values get a marker; every other type is shown by name.
    value_labels = [
        "dict (nested)" if kind == dict else kind.__name__
        for kind in value_kinds
    ]
    print(f"{pad}  Values: {' | '.join(value_labels)}")

    if not children:
        return

    print(f"{pad}  Nested dictionaries:")
    for child_key, child in children:
        print(f"{pad}    -> '{child_key}' (dict)")
        analyze_dict_structure(
            child,
            indent=indent + 2,
            seen_structures=seen_structures,
            path=f"{path}.{child_key}",
        )


def print_dict_structure(data: Dict) -> None:
    """
    Print the structure report for ``data`` between a header and two rules.

    Args:
        data: The dictionary to analyze, starting at the root level.
    """
    rule = "=" * 50
    print("Dictionary Structure Analysis:")
    print(rule)
    analyze_dict_structure(data)
    print(rule)


# Hand-built sample with flat values, a two-level nesting under "c", and a
# sibling dict "h".
example = {
    "a": 1, "b": "string",
    "c": {"d": 2, "e": {"f": 3, "g": "nested_string"}},
    "h": {"i": 4},
}

# Report on the sample first, then on the loaded sweep results.
print_dict_structure(example)
print_dict_structure(results)

## Visualize Stability Metrics

Create visualizations of stability metrics across K values and summarizers.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Plot each stability metric (mean line with a ±std band) against K on a 2x2 grid,
# one line per summarizer. Reads `results` (and `dataset_name`, if defined).
if results:
    # (key inside each K's "stability" dict, subplot title)
    metrics_to_plot = [
        ("silhouette", "Cosine Silhouette (mean ± std)"),
        ("stability_ari", "Stability ARI (mean ± std)"),
        ("dispersion", "Reconstruction Error / point (mean ± std)"),
        ("coverage", "Coverage Fraction (mean ± std)"),
    ]

    # Only shown in the figure title; this cell does not read it from the saved data.
    coverage_threshold = 0.2  # Default, update if stored in results

    fig, axes = plt.subplots(2, 2, figsize=(14, 10), sharex=True)
    axes = axes.flatten()

    # Use different line styles and markers to distinguish overlapping lines
    line_styles = ['-', '--', '-.', ':']
    markers = ['o', 's', '^', 'D', 'v', 'p', '*', 'h']
    colors = plt.cm.tab10(np.linspace(0, 1, len(results)))

    for ax, (metric_key, title) in zip(axes, metrics_to_plot):
        plotted_count = 0
        for idx, summarizer_name in enumerate(sorted(results.keys())):
            # Best-effort: one malformed summarizer entry is reported and skipped
            # rather than aborting the whole figure.
            try:
                k_data = results[summarizer_name]["by_k"]

                # Read every entry through the key it is stored under, so the
                # lookup works whether the K keys are strings or ints.
                points = []
                for k_key, entry in k_data.items():
                    stability = entry.get("stability")
                    if stability and metric_key in stability:
                        stats = stability[metric_key]
                        points.append((int(k_key), stats["mean"], stats["std"]))

                if not points:
                    continue
                points.sort(key=lambda point: point[0])

                valid_ks = [point[0] for point in points]
                means = [point[1] for point in points]
                stds = [point[2] for point in points]

                # Cycle styles and markers so neighbouring summarizers differ
                color = colors[idx]
                ax.plot(
                    valid_ks, means,
                    marker=markers[idx % len(markers)],
                    linestyle=line_styles[idx % len(line_styles)],
                    label=summarizer_name,
                    linewidth=2.5,
                    color=color,
                    markersize=6,
                    markeredgewidth=1.5,
                )
                mean_arr = np.array(means)
                std_arr = np.array(stds)
                ax.fill_between(
                    valid_ks,
                    mean_arr - std_arr,
                    mean_arr + std_arr,
                    alpha=0.15,
                    color=color,
                )
                plotted_count += 1
            except Exception as e:
                print(f"[WARN]  Error plotting {summarizer_name} for {metric_key}: {e}")
                continue

        if plotted_count == 0:
            print(f"[WARN]  No data to plot for {metric_key}")

        ax.set_title(title, fontsize=11, fontweight='bold')
        ax.set_xlabel("K (number of clusters)", fontsize=10)
        ax.set_ylabel(metric_key.replace('_', ' ').title(), fontsize=10)
        ax.grid(True, alpha=0.3)
        # A legend on axes without labelled lines only produces a matplotlib warning.
        if plotted_count:
            ax.legend(fontsize=8)

    # Add overall title
    title_parts = ["Stability Metrics vs K for Different Summarizers"]
    if 'dataset_name' in globals() and dataset_name != "unknown":
        title_parts.append(f"Dataset: {dataset_name}")
    title_parts.append(f"Coverage threshold: {coverage_threshold}")

    fig.suptitle(
        " | ".join(title_parts),
        fontsize=12,
        fontweight='bold',
        y=0.995
    )

    # Leave headroom at the top of the figure for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()

    print(f"\n[OK] Visualized {len(results)} summarizer(s) across K values")
else:
    print("[WARN]  No results to visualize")

## Ground Truth Comparison (if available)

Compare predicted cluster labels against ground truth labels using the Adjusted Rand Index (ARI).

In [None]:
# Ground truth comparison visualization.
# For every summarizer and K with stored labels, score the labels against the
# ground truth and plot the score vs K, marking the true cluster count.
if 'ground_truth_labels' in globals() and ground_truth_labels is not None:
    from study_query_llm.algorithms.clustering import adjusted_rand_index

    true_k = len(set(ground_truth_labels))
    print(f"[INFO] Computing ground truth ARI for dataset: {dataset_name}")
    print(f"   Ground truth clusters: {true_k}")

    # Compute ARI for each K and summarizer: {summarizer: {K (int): ARI}}
    gt_ari_data = {}
    for summarizer_name in sorted(results.keys()):
        gt_ari_data[summarizer_name] = {}
        k_data = results[summarizer_name]["by_k"]

        # Pair each K with its labels directly, so the lookup works for both
        # string and int keys; process in ascending K.
        labelled = sorted(
            (
                (int(k_key), entry["labels"])
                for k_key, entry in k_data.items()
                if entry.get("labels") is not None
            ),
            key=lambda item: item[0],
        )

        for k, raw_labels in labelled:
            labels = np.array(raw_labels)
            if len(labels) != len(ground_truth_labels):
                # Report the skip so a missing point on the plot is explainable.
                print(
                    f"[WARN]  Skipping {summarizer_name} K={k}: {len(labels)} labels "
                    f"vs {len(ground_truth_labels)} ground truth samples"
                )
                continue
            ari = adjusted_rand_index(labels, ground_truth_labels)
            gt_ari_data[summarizer_name][k] = ari

    # Create visualization
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))

    # Plot ARI vs K for each summarizer
    line_styles = ['-', '--', '-.', ':']
    markers = ['o', 's', '^', 'D', 'v', 'p', '*', 'h']
    colors = plt.cm.tab10(np.linspace(0, 1, len(gt_ari_data)))

    for idx, summarizer_name in enumerate(sorted(gt_ari_data.keys())):
        ks = sorted(gt_ari_data[summarizer_name].keys())
        if not ks:
            # Nothing comparable for this summarizer; avoid an empty legend entry.
            continue
        aris = [gt_ari_data[summarizer_name][k] for k in ks]

        ax.plot(
            ks, aris,
            marker=markers[idx % len(markers)],
            linestyle=line_styles[idx % len(line_styles)],
            label=summarizer_name,
            linewidth=2.5,
            color=colors[idx],
            markersize=8,
            markeredgewidth=1.5,
        )

    # Add vertical line at true number of clusters
    ax.axvline(x=true_k, color='red', linestyle='--', linewidth=2, alpha=0.7, label=f'True K={true_k}')

    ax.set_xlabel("K (number of clusters)", fontsize=11)
    ax.set_ylabel("Adjusted Rand Index vs Ground Truth", fontsize=11)
    ax.set_title(f"Ground Truth Comparison: {dataset_name}", fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=9)
    ax.set_ylim([-0.1, 1.1])

    plt.tight_layout()
    plt.show()

    # Print best K for each summarizer
    print("\n[INFO] Best K by ground truth ARI:")
    for summarizer_name in sorted(gt_ari_data.keys()):
        scores = gt_ari_data[summarizer_name]
        if not scores:
            # max() raises ValueError on an empty mapping.
            print(f"   {summarizer_name}: no comparable labels")
            continue
        best_k = max(scores.items(), key=lambda x: x[1])
        print(f"   {summarizer_name}: K={best_k[0]} (ARI={best_k[1]:.3f})")

else:
    print("[INFO] No ground truth labels available for this dataset.")
    print("   Ground truth comparison requires benchmark datasets with known cluster labels.")

In [None]:
# Print the stored representative texts of every summarizer for one chosen K.
K_TO_COMPARE = 5

print(f"\n{'='*60}\nComparing representatives for K={K_TO_COMPARE}\n{'='*60}")

for summarizer_name in sorted(results):
    by_k = results[summarizer_name]["by_k"]
    # Summarizers without an entry for this K are skipped.
    if str(K_TO_COMPARE) not in by_k:
        continue

    reps = by_k[str(K_TO_COMPARE)].get("representatives", [])
    print(f"\n{summarizer_name}:")
    for i, rep in enumerate(reps, 1):
        # Show at most 120 characters and mark truncated texts with an ellipsis.
        head = rep[:120]
        suffix = '...' if len(rep) > 120 else ''
        print(f"  {i}. {head}{suffix}")

## Test: Check if Data is Identical or Just Very Close

Test whether summarizers have truly identical data (bit-for-bit) or if values are just very close (within floating-point precision).

In [None]:
import numpy as np

# Check whether the summarizers' stability metric means are bit-for-bit identical,
# merely very close, or different: first in detail for one K, then as a single
# verdict over every K present in the results.
if results:
    print(f"\n{'='*60}")
    print("PRECISION TEST: Identical vs Very Close")
    print(f"{'='*60}\n")

    # Absolute-difference thresholds, tightest first: (threshold, label)
    tolerances = [
        (1e-15, "Machine epsilon (1e-15)"),
        (1e-12, "Very tight (1e-12)"),
        (1e-10, "Tight (1e-10)"),
        (1e-8, "Moderate (1e-8)"),
        (1e-6, "Loose (1e-6)"),
    ]

    # Get all summarizers
    summarizers = sorted(results.keys())

    if len(summarizers) < 2:
        print("[WARN]  Need at least 2 summarizers to compare")
    else:
        # Test for a specific K value across all metrics
        test_k = "5"
        metrics_to_test = ["silhouette", "stability_ari", "dispersion", "coverage"]

        print(f"Testing K={test_k} across {len(summarizers)} summarizers\n")

        # Collect all metric values: {metric: {summarizer: mean}}
        metric_data = {}
        for metric_key in metrics_to_test:
            metric_data[metric_key] = {}
            for name in summarizers:
                k_data = results[name]["by_k"]
                if test_k in k_data and k_data[test_k].get("stability"):
                    stability = k_data[test_k]["stability"]
                    if metric_key in stability:
                        metric_data[metric_key][name] = stability[metric_key]["mean"]

        # Say so explicitly when the chosen K has no data, instead of printing nothing.
        if not any(metric_data.values()):
            print(f"[WARN]  No stability data found for K={test_k}; skipping per-metric comparison")

        # Test each metric
        for metric_key in metrics_to_test:
            if not metric_data[metric_key]:
                continue

            print(f"\n{metric_key.upper().replace('_', ' ')}:")
            print("-" * 60)

            values = list(metric_data[metric_key].values())
            if len(values) < 2:
                print("  [WARN]  Not enough data to compare")
                continue

            # Check if all values are exactly equal (bit-for-bit identical)
            all_identical = all(v == values[0] for v in values)

            if all_identical:
                print(f"  [OK] BIT-FOR-BIT IDENTICAL: All values = {values[0]}")
                print("     This means the data is truly identical, not just close")
            else:
                # Spread between the smallest and largest mean
                min_val = min(values)
                max_val = max(values)
                diff = max_val - min_val
                # Fall back to the absolute spread when the minimum is exactly zero.
                rel_diff = diff / abs(min_val) if min_val != 0 else diff

                print(f"  Values: {[f'{v:.12f}' for v in values]}")
                print(f"  Range: [{min_val:.12f}, {max_val:.12f}]")
                print(f"  Absolute difference: {diff:.2e}")
                print(f"  Relative difference: {rel_diff:.2e}")

                # Test against each tolerance
                print("\n  Tolerance tests:")
                for tol, desc in tolerances:
                    within_tol = diff < tol
                    status = "[OK]" if within_tol else "[ERROR]"
                    print(f"    {status} {desc}: {'WITHIN' if within_tol else 'EXCEEDS'} tolerance")

                # Determine if "very close" or "different"
                if diff < 1e-12:
                    print("\n  [INFO] CONCLUSION: Values are VERY CLOSE (likely same computation, minor FP differences)")
                elif diff < 1e-8:
                    print("\n  [INFO] CONCLUSION: Values are CLOSE (may be same computation with some variation)")
                else:
                    print("\n  [INFO] CONCLUSION: Values are DIFFERENT (likely different computations)")

        # Additional test: Check if all metrics are identical across all K values
        print(f"\n{'='*60}")
        print("ACROSS-ALL-K TEST")
        print(f"{'='*60}\n")

        all_k_identical = True
        all_k_very_close = True
        # Number of (K, metric) pairs that had 2+ values; guards the verdict below
        # against being vacuously "identical" when nothing was compared.
        comparisons_made = 0

        # Use the K values of every summarizer, not just the first one.
        all_ks = sorted({int(k) for name in summarizers for k in results[name]["by_k"]})

        for k in all_ks:
            k_str = str(k)
            for metric_key in metrics_to_test:
                values = []
                for name in summarizers:
                    k_data = results[name]["by_k"]
                    if k_str in k_data and k_data[k_str].get("stability"):
                        stability = k_data[k_str]["stability"]
                        if metric_key in stability:
                            values.append(stability[metric_key]["mean"])

                if len(values) > 1:
                    comparisons_made += 1
                    # Check if identical
                    if not all(v == values[0] for v in values):
                        all_k_identical = False
                    # Check if very close
                    if max(values) - min(values) >= 1e-10:
                        all_k_very_close = False

        if comparisons_made == 0:
            print("[WARN]  No metric had values from 2+ summarizers at any K; nothing to compare")
        elif all_k_identical:
            print("[OK] ALL metrics are BIT-FOR-BIT IDENTICAL across all K values and summarizers")
            print("   This confirms the data is truly identical (same computation)")
        elif all_k_very_close:
            print("[OK] ALL metrics are VERY CLOSE (< 1e-10) across all K values and summarizers")
            print("   Values are essentially identical (likely same computation with minor FP differences)")
        else:
            print("[WARN]  Some metrics differ significantly across K values or summarizers")
            print("   This suggests there may be actual differences in the computations")
else:
    print("[WARN]  No results to test")