# Study Query LLM - PCA KLLMeans Sweep (Jupyter)

This notebook:
1) Loads prompt texts from the dictionary
2) Fetches/stores embeddings using `EmbeddingService` (v2 schema)
3) Runs the PCA KLLMeans sweep using the `algorithms` library
4) Tests multiple LLM summarizers (3 LLMs + None) for stability analysis
5) Tracks provenance using `ProvenanceService`

**Note:** This is the Jupyter version (not Colab-specific). For Colab, use `colab_pca_kllmeans_sweep.ipynb`.

## Install dependencies (if needed)

Uncomment if running in a fresh environment:

In [None]:
# %pip install openai python-dotenv sqlalchemy psycopg2-binary nest_asyncio tqdm numpy
# Note: python-dotenv is optional - the notebook will work without it

## Configure environment

Set your environment variables or load from `.env` file:

In [None]:
import os

# Try to load from .env file if python-dotenv is available
try:
    from dotenv import load_dotenv
    load_dotenv()
    print("✅ Loaded environment variables from .env file (if present)")
except ImportError:
    print("ℹ️  python-dotenv not installed. Skipping .env file loading.")
    print("   Install with: pip install python-dotenv")
    print("   Or set environment variables directly below.")

# Set environment variables (or use .env file)
os.environ.setdefault("AZURE_OPENAI_API_KEY", "your-azure-api-key")
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://your-resource.openai.azure.com/")
os.environ.setdefault("AZURE_OPENAI_EMBEDDING_DEPLOYMENT", "text-embedding-3-small")
os.environ.setdefault("AZURE_OPENAI_API_VERSION", "2024-02-15-preview")
os.environ.setdefault("AZURE_OPENAI_DEPLOYMENT", "gpt-4o-mini")
os.environ.setdefault(
    "DATABASE_URL",
    "postgresql://username:password@host:port/database?sslmode=require"
)

print("\nEnvironment variables:")
print("AZURE_OPENAI_ENDPOINT:", os.environ.get("AZURE_OPENAI_ENDPOINT"))
print("AZURE_OPENAI_EMBEDDING_DEPLOYMENT:", os.environ.get("AZURE_OPENAI_EMBEDDING_DEPLOYMENT"))
print("AZURE_OPENAI_DEPLOYMENT:", os.environ.get("AZURE_OPENAI_DEPLOYMENT"))
print("DATABASE_URL set:", bool(os.environ.get("DATABASE_URL")))
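
The `setdefault` calls above leave placeholder values in place when nothing else is configured, and later cells then fail with less obvious connection errors. A minimal pre-flight sketch, assuming the placeholder strings are exactly the defaults used in this cell; `find_placeholder_env` and `PLACEHOLDER_MARKERS` are hypothetical names, not part of `study_query_llm`.

```python
import os

# Placeholder fragments copied from the setdefault() defaults in the cell above.
PLACEHOLDER_MARKERS = {
    "AZURE_OPENAI_API_KEY": "your-azure-api-key",
    "AZURE_OPENAI_ENDPOINT": "your-resource",
    "DATABASE_URL": "username:password@host:port",
}


def find_placeholder_env(env=None):
    """Return the names of variables that are unset or still hold a placeholder."""
    env = os.environ if env is None else env
    problems = []
    for name, marker in PLACEHOLDER_MARKERS.items():
        value = env.get(name, "")
        if not value or marker in value:
            problems.append(name)
    return problems


unconfigured = find_placeholder_env()
if unconfigured:
    print("Still unset or placeholder:", ", ".join(unconfigured))
```

Running this before the database cell makes a forgotten `.env` file visible immediately.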


## Initialize database and services

In [None]:
import time
import asyncio
import nest_asyncio
import numpy as np
from tqdm import tqdm
from study_query_llm.db.connection_v2 import DatabaseConnectionV2
from study_query_llm.db.raw_call_repository import RawCallRepository
from study_query_llm.services.embedding_service import EmbeddingService, EmbeddingRequest
from study_query_llm.services.summarization_service import SummarizationService, SummarizationRequest
from study_query_llm.services.provenance_service import ProvenanceService
from study_query_llm.algorithms import SweepConfig, run_sweep

nest_asyncio.apply()

# Initialize DB connection (creates tables if needed)
db = DatabaseConnectionV2(os.environ["DATABASE_URL"], enable_pgvector=True)
db.init_db()

# Get embedding deployment
embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
print(f"Using embedding deployment: {embedding_deployment}")
print("✅ Database ready (services are created per operation)")

## Prompt dictionary

Load your prompt dictionary here. You can:
1. Define it directly in the cell below
2. Load from a pickle file: `database_estela_dict = pickle.load(open('file.pkl', 'rb'))`
3. Load from a JSON file: `import json; database_estela_dict = json.load(open('file.json', 'r'))`

In [None]:
# Option 1: Define directly
# database_estela_dict = {
#     'path/to/file.yaml': {
#         'generation prompts': [
#             {'prompt 1': 'Your prompt text here...'},
#         ],
#     },
# }

# Option 2: Load from pickle
# import pickle
# database_estela_dict = pickle.load(open('your_dict.pkl', 'rb'))

# Option 3: Load from JSON
# import json
# database_estela_dict = json.load(open('your_dict.json', 'r', encoding='utf-8'))

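# Fallback: use an empty dict only if none of the options above defined one.
# (An unconditional assignment here would overwrite the dict loaded above.)
if "database_estela_dict" not in globals():
    database_estela_dict = {}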

## Flatten prompts and extract texts

In [None]:
# Flatten nested prompt dictionary into a flat map of key tuples -> prompt strings

def _is_prompt_key(key: str) -> bool:
    key_lower = key.lower()
    return "prompt" in key_lower  # catches "prompt", "prompt id", "prompt #", etc


def flatten_prompt_dict(data, path=()):
    flat = {}

    if isinstance(data, dict):
        for key, value in data.items():
            new_path = path + (key,)
            if isinstance(key, str) and _is_prompt_key(key) and isinstance(value, str):
                flat[new_path] = value
            else:
                flat.update(flatten_prompt_dict(value, new_path))
    elif isinstance(data, list):
        for i, value in enumerate(data):
            new_path = path + (f"[{i}]",)
            flat.update(flatten_prompt_dict(value, new_path))

    return flat


# Flatten prompts and extract texts
flat_prompts = flatten_prompt_dict(database_estela_dict)
texts = list(flat_prompts.values())

# Filter out empty/invalid strings (EmbeddingService will reject them anyway)
def _clean_texts(texts_list: list[str]) -> list[str]:
    """Clean and filter texts."""
    cleaned = []
    for text in texts_list:
        if text is None:
            continue
        if not isinstance(text, str):
            text = str(text)
        text = text.replace("\x00", "").strip()
        if text:  # Only keep non-empty strings
            cleaned.append(text)
    return cleaned


texts = _clean_texts(texts)
print(f"✅ Flattened {len(texts)} valid prompts")

# Show a few samples
for i, (k, v) in enumerate(list(flat_prompts.items())[:3]):
    print(f"\nSample {i+1} - Key: {k}")
    print(f"  Text: {v[:150]}{'...' if len(v) > 150 else ''}")
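
After `_clean_texts` drops empty entries, `texts` no longer lines up index-by-index with `flat_prompts`. If you need to map a cluster label back to its source key, one option is to clean keys and texts together. A sketch under that assumption; `clean_prompt_items` is a hypothetical helper and the toy data is invented for illustration.

```python
def clean_prompt_items(flat):
    """Return (key_path, cleaned_text) pairs, dropping entries that clean to empty."""
    items = []
    for key_path, text in flat.items():
        if text is None:
            continue
        cleaned = str(text).replace("\x00", "").strip()
        if cleaned:
            items.append((key_path, cleaned))
    return items


# Toy input shaped like the output of flatten_prompt_dict (hypothetical data).
toy_flat = {
    ("file.yaml", "generation prompts", "[0]", "prompt 1"): "  Explain PCA.  ",
    ("file.yaml", "generation prompts", "[1]", "prompt 2"): "   ",
    ("file.yaml", "generation prompts", "[2]", "prompt 3"): "Summarize\x00 this.",
}
items = clean_prompt_items(toy_flat)
keys_aligned = [key for key, _ in items]
texts_aligned = [text for _, text in items]
print(texts_aligned)
```

Here `keys_aligned[i]` is the path tuple for `texts_aligned[i]`, so row `i` of the embedding matrix can be traced to its prompt.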

## Fetch embeddings using EmbeddingService

In [None]:
async def fetch_embeddings_async(texts_list: list[str], deployment: str) -> np.ndarray:
    """Fetch or create embeddings using EmbeddingService."""
    with db.session_scope() as session:
        repo = RawCallRepository(session)
        service = EmbeddingService(repository=repo)

        # Create embedding requests
        requests = [
            EmbeddingRequest(text=text, deployment=deployment)
            for text in texts_list
        ]

        # Get embeddings (will use cache if available)
        responses = await service.get_embeddings_batch(requests)

        # Extract vectors
        embeddings = [resp.vector for resp in responses]

        return np.asarray(embeddings, dtype=np.float64)


# Fetch embeddings
print(f"Fetching embeddings for {len(texts)} texts using {embedding_deployment}...")
embeddings = await fetch_embeddings_async(texts, embedding_deployment)
print(f"✅ Got embeddings: shape {embeddings.shape}")
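
Before running the sweep it can help to sanity-check the embedding matrix. The sketch below is an assumption about what is worth checking, not something `EmbeddingService` or `run_sweep` is known to require; `check_embeddings` is a hypothetical helper. The `pca_dim` check rests on the general fact that a PCA has at most `min(n_rows, n_features)` components.

```python
import numpy as np


def check_embeddings(embeddings, n_texts, pca_dim):
    """Return a list of human-readable problems; an empty list means none were found."""
    problems = []
    arr = np.asarray(embeddings, dtype=np.float64)
    if arr.ndim != 2:
        return [f"expected a 2-D array, got shape {arr.shape}"]
    n_rows, n_features = arr.shape
    if n_rows != n_texts:
        problems.append(f"{n_rows} embeddings for {n_texts} texts")
    if not np.isfinite(arr).all():
        problems.append("array contains NaN or infinite values")
    if n_rows and (np.linalg.norm(arr, axis=1) == 0).any():
        problems.append("at least one embedding is all zeros")
    # A PCA cannot yield more components than min(n_rows, n_features).
    limit = min(n_rows, n_features)
    if pca_dim > limit:
        problems.append(f"pca_dim={pca_dim} exceeds min(n_rows, n_features)={limit}")
    return problems
```

In this notebook the call would be `check_embeddings(embeddings, len(texts), 64)`, where 64 is the `pca_dim` used in the sweep configuration.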

## Run PCA KLLMeans Sweep with Multiple LLM Summarizers
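
The cell in this section runs 4 summarizers over k=2 to k=10 with 20 restarts each. A back-of-the-envelope count of clustering fits, assuming one fit per (summarizer, k, restart) combination; how `run_sweep` schedules work internally is not documented here, and `count_clustering_fits` is a hypothetical helper.

```python
def count_clustering_fits(n_summarizers, k_min, k_max, n_restarts):
    """Number of clustering fits, assuming one fit per (summarizer, k, restart)."""
    n_k_values = k_max - k_min + 1  # the k range is inclusive on both ends
    return n_summarizers * n_k_values * n_restarts


# Values used in this notebook: 4 summarizers (3 LLMs + None), k=2..10, 20 restarts.
total_fits = count_clustering_fits(n_summarizers=4, k_min=2, k_max=10, n_restarts=20)
print(total_fits)  # 4 * 9 * 20 = 720
```

Lowering `n_restarts` or narrowing the k range is the quickest way to shorten a trial run.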

In [None]:
# Define LLM deployments for summarization (3 LLMs + None = 4 runs)
llm_summarizers = [
    None,  # Non-LLM summaries (just use original representatives)
    "gpt-4o-mini",
    "gpt-4o",
    "gpt-5-chat-2025-08-07",  # Add your preferred third LLM here
]

# Configure sweep: k=2 to k=10
# Enable stability metrics with multiple restarts
cfg = SweepConfig(
    pca_dim=64,
    rank_r=2,
    k_min=2,
    k_max=10,
    max_iter=200,
    base_seed=0,
    n_restarts=20,  # Multiple restarts for stability analysis
    compute_stability=True,  # Enable stability metrics (silhouette, ARI, dispersion, coverage)
    coverage_threshold=0.2,  # Cosine distance threshold for coverage metric
)

# Helper to create paraphraser using SummarizationService
def create_paraphraser_for_llm(llm_deployment: str):
    """Create a synchronous paraphraser function for a specific LLM deployment."""
    if llm_deployment is None:
        return None

    async def _paraphrase_batch_async(texts: list[str]) -> list[str]:
        """Async wrapper for summarization."""
        with db.session_scope() as session:
            repo = RawCallRepository(session)
            service = SummarizationService(repository=repo)

            request = SummarizationRequest(
                texts=texts,
                llm_deployment=llm_deployment,
                temperature=0.2,
                max_tokens=128,
            )

            result = await service.summarize_batch(request)
            return result.summaries

    def paraphrase_batch_sync(texts: list[str]) -> list[str]:
        """Synchronous wrapper for run_sweep."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (e.g. a plain script): start a fresh one
            return asyncio.run(_paraphrase_batch_async(texts))
        # Running loop (Jupyter): nest_asyncio.apply() above allows re-entry
        return loop.run_until_complete(_paraphrase_batch_async(texts))

    return paraphrase_batch_sync


# Store all results
all_results = {}

# Create run group for provenance tracking
with db.session_scope() as session:
    repo = RawCallRepository(session)
    provenance = ProvenanceService(repository=repo)

    run_group_id = provenance.create_run_group(
        name=f"pca_kllmeans_sweep_{embedding_deployment}",
        metadata={
            "embedding_deployment": embedding_deployment,
            "n_texts": len(texts),
            "k_range": f"{cfg.k_min}-{cfg.k_max}",
            "llm_summarizers": [s if s else "None" for s in llm_summarizers],
        },
    )
    print(f"✅ Created run group: id={run_group_id}")

# Run sweep for each LLM summarizer
for llm_deployment in tqdm(llm_summarizers, desc="LLM Summarizers"):
    summarizer_name = "None" if llm_deployment is None else llm_deployment
    print(f"\n{'='*60}")
    print(f"Running sweep with summarizer: {summarizer_name}")
    print(f"{'='*60}")

    # Create paraphraser
    paraphraser = create_paraphraser_for_llm(llm_deployment)

    # Run sweep
    result = run_sweep(texts, embeddings, cfg, paraphraser=paraphraser)
    all_results[summarizer_name] = result

    print(f"✅ Completed. Ks: {sorted([int(k) for k in result.by_k.keys()])}")

# Summary
print(f"\n{'='*60}")
print("SUMMARY")
print(f"{'='*60}")
print(f"Embedding deployment: {embedding_deployment}")
print(f"Number of texts: {len(texts)}")
print(f"K range: {cfg.k_min} to {cfg.k_max}")
print(
    f"Summarizers tested: {len(llm_summarizers)} ({', '.join([s if s else 'None' for s in llm_summarizers])})"
)
print(f"\nResults structure: all_results[summarizer_name].by_k[k_value]")
print(f"\nExample access:")
print(f"  all_results['None'].by_k['5']['representatives']")
if any(s for s in llm_summarizers if s):
    first_llm = next(s for s in llm_summarizers if s)
    print(f"  all_results['{first_llm}'].by_k['5']['representatives']")

## View Results

In [None]:
# Display results for each summarizer
for summarizer_name, result in all_results.items():
    print(f"\n{'='*60}")
    print(f"Results for summarizer: {summarizer_name}")
    print(f"{'='*60}")

    for k in sorted([int(k) for k in result.by_k.keys()]):
        k_data = result.by_k[str(k)]
        reps = k_data.get("representatives", [])
        print(f"\nK={k}: {len(reps)} representatives")
        for i, rep in enumerate(reps[:3], 1):  # Show first 3
            print(f"  {i}. {rep[:100]}{'...' if len(rep) > 100 else ''}")
        if len(reps) > 3:
            print(f"  ... and {len(reps) - 3} more")
        
        # Show stability metrics if available
        if k_data.get("stability"):
            stab = k_data["stability"]
            print(f"    Silhouette: {stab['silhouette']['mean']:.3f} ± {stab['silhouette']['std']:.3f}")
            print(f"    Stability ARI: {stab['stability_ari']['mean']:.3f} ± {stab['stability_ari']['std']:.3f}")
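
To go from the printed metrics to a candidate k, one simple heuristic is to pick the k with the highest mean silhouette. The sketch assumes the `by_k[k]['stability'][metric]['mean']` layout used in the cell above; `best_k_by_metric` is a hypothetical helper and the toy values are invented.

```python
def best_k_by_metric(by_k, metric="silhouette"):
    """Return (k, mean) with the highest mean for `metric`, or None if no k has it."""
    scored = []
    for k, k_data in by_k.items():
        stability = k_data.get("stability") or {}
        stats = stability.get(metric)
        if stats and "mean" in stats:
            scored.append((int(k), stats["mean"]))
    if not scored:
        return None
    return max(scored, key=lambda pair: pair[1])


# Toy data in the same shape as result.by_k (hypothetical numbers).
toy_by_k = {
    "2": {"stability": {"silhouette": {"mean": 0.31, "std": 0.02}}},
    "3": {"stability": {"silhouette": {"mean": 0.42, "std": 0.03}}},
    "4": {"stability": None},
}
print(best_k_by_metric(toy_by_k))
```

With real results the call would be `best_k_by_metric(all_results["None"].by_k)`. A single metric is rarely decisive, so treat the answer as a starting point.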

## Save Results (Complete - All Metrics + Matrices)

In [None]:
# Save results to a pickle file for later analysis
# Includes all metrics, labels, objectives, and distance matrices for recomputation
import pickle
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = f"pca_kllmeans_sweep_results_{timestamp}.pkl"

# Prepare results for saving (convert numpy arrays to plain lists so the saved structures are plain Python)
save_results = {}
for summarizer_name, result in all_results.items():
    save_results[summarizer_name] = {
        "pca": result.pca,
        "by_k": {},
    }

    # Save intermediate matrices (Z, Z_norm, dist) for later metric computation
    if result.Z is not None:
        save_results[summarizer_name]["Z"] = result.Z.tolist()
    if result.Z_norm is not None:
        save_results[summarizer_name]["Z_norm"] = result.Z_norm.tolist()
    if result.dist is not None:
        save_results[summarizer_name]["dist"] = result.dist.tolist()

    # Save all data for each K value
    for k, k_data in result.by_k.items():
        save_results[summarizer_name]["by_k"][k] = {
            "representatives": k_data.get("representatives", []),
            "labels": (
                k_data.get("labels", []).tolist()
                if hasattr(k_data.get("labels"), "tolist")
                else k_data.get("labels", [])
            ),
            "labels_all": (
                [
                    l.tolist() if hasattr(l, "tolist") else l
                    for l in k_data.get("labels_all", [])
                ]
                if k_data.get("labels_all") is not None
                else None
            ),
            "objective": k_data.get("objective", {}),
            "objectives": k_data.get("objectives", []),  # All restart objectives
            "stability": k_data.get("stability"),  # Stability metrics (silhouette, ARI, dispersion, coverage)
        }

with open(output_file, "wb") as f:
    pickle.dump(save_results, f)

print(f"✅ Results saved to: {output_file}")
print(
    f"   Includes: representatives, labels, objectives, stability metrics, and distance matrices"
)
print(f"   To load: results = pickle.load(open('{output_file}', 'rb'))")
print(f"\n   Example access:")
print(f"     results['None']['by_k']['5']['stability']['silhouette']['mean']")
print(f"     results['None']['by_k']['5']['representatives']")
print(f"     results['None']['dist']  # Distance matrix for recomputing metrics")
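
The cell above repeats the `hasattr(..., "tolist")` check for each field. A recursive converter can do the same in one place. This is a sketch, not part of the `algorithms` library; `to_builtin` is a hypothetical helper, and it turns tuples into lists as a side effect.

```python
def to_builtin(obj):
    """Recursively convert anything with .tolist() (e.g. numpy arrays) to plain Python."""
    if hasattr(obj, "tolist"):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: to_builtin(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_builtin(value) for value in obj]
    return obj
```

For example, `to_builtin(k_data)` would convert `labels`, each entry of `labels_all`, and any arrays nested inside `stability` in a single call.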

## Compare Representatives Across Summarizers

In [None]:
# Compare representatives for a specific K value
K_TO_COMPARE = 5

print(f"\n{'='*60}")
print(f"Comparing representatives for K={K_TO_COMPARE}")
print(f"{'='*60}")

for summarizer_name in sorted(all_results.keys()):
    if str(K_TO_COMPARE) in all_results[summarizer_name].by_k:
        reps = all_results[summarizer_name].by_k[str(K_TO_COMPARE)].get(
            "representatives", []
        )
        print(f"\n{summarizer_name}:")
        for i, rep in enumerate(reps, 1):
            print(f"  {i}. {rep[:120]}{'...' if len(rep) > 120 else ''}")
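
Representatives are free text and hard to compare by eye. A complementary check is to compare the cluster assignments of two summarizers at the same K with the adjusted Rand index (ARI). The sketch below implements the standard ARI formula in pure Python; it is not necessarily how the library computes `stability_ari`, and it assumes both `labels` lists cover the same texts in the same order.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(labels_a, labels_b):
    """Standard adjusted Rand index between two labelings of the same items."""
    if len(labels_a) != len(labels_b):
        raise ValueError("label lists must have the same length")
    total_pairs = comb(len(labels_a), 2)
    if total_pairs == 0:
        return 1.0
    # Pairs of items that share a cluster in both labelings, in A, and in B.
    sum_cells = sum(comb(c, 2) for c in Counter(zip(labels_a, labels_b)).values())
    sum_a = sum(comb(c, 2) for c in Counter(labels_a).values())
    sum_b = sum(comb(c, 2) for c in Counter(labels_b).values())
    expected = sum_a * sum_b / total_pairs
    max_index = (sum_a + sum_b) / 2
    if max_index == expected:
        return 1.0  # degenerate case: both labelings are trivial
    return (sum_cells - expected) / (max_index - expected)
```

With real results, something like `adjusted_rand_index(list(a["labels"]), list(b["labels"]))` for two summarizers' `by_k[str(K_TO_COMPARE)]` entries gives 1.0 for identical partitions and values near 0 for unrelated ones.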

## Notes

- All embeddings are cached in the database (v2 schema: `RawCall` + `EmbeddingVector`)
- All LLM summarization calls are logged to `RawCall` with full provenance
- Results are stored in `all_results[summarizer_name].by_k[k_value]` (the saved pickle uses plain dicts: `results[summarizer_name]['by_k'][k_value]`)
- **Stability metrics enabled**: 20 restarts per K with full metrics (silhouette, ARI, dispersion, coverage)
- **Complete data saved**: All metrics, labels, objectives, and distance matrices saved to pickle for later analysis
- You can load the pickle file later and recompute or analyze metrics without re-running the sweep

## Load and Analyze Saved Results (Later)

Use this cell to load previously saved results and analyze metrics without re-running the sweep.

In [None]:
# Load saved results
# import pickle
# import numpy as np

# Replace with your actual pickle filename
# saved_file = "pca_kllmeans_sweep_results_20250203_123456.pkl"
# results = pickle.load(open(saved_file, "rb"))

# Example: Access stability metrics
# for summarizer_name in results.keys():
#     print(f"\n{summarizer_name}:")
#     for k in sorted([int(k) for k in results[summarizer_name]["by_k"].keys()]):
#         k_data = results[summarizer_name]["by_k"][str(k)]  # keys are strings
#         if k_data.get("stability"):
#             stab = k_data["stability"]
#             print(f"  K={k}:")
#             print(f"    Silhouette: {stab['silhouette']['mean']:.3f} ± {stab['silhouette']['std']:.3f}")
#             print(f"    Stability ARI: {stab['stability_ari']['mean']:.3f} ± {stab['stability_ari']['std']:.3f}")
#             print(f"    Dispersion: {stab['dispersion']['mean']:.3f} ± {stab['dispersion']['std']:.3f}")
#             print(f"    Coverage: {stab['coverage']['mean']:.3f} ± {stab['coverage']['std']:.3f}")

# Example: Recompute metrics using saved distance matrices
# if "dist" in results["None"]:
#     dist_matrix = np.array(results["None"]["dist"])
#     labels = np.array(results["None"]["by_k"]["5"]["labels"])
#     # Now you can recompute metrics using the algorithms library functions
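
As one example of recomputing a metric from the saved data, the sketch below computes the mean pairwise distance inside each cluster. It assumes `dist` is a square n×n pairwise distance matrix aligned with `labels`; the function is a hypothetical helper and is not claimed to match the library's `dispersion` definition.

```python
import numpy as np


def mean_intra_cluster_distance(dist, labels):
    """Mean pairwise distance between distinct points of each cluster, keyed by label."""
    dist = np.asarray(dist, dtype=np.float64)
    labels = np.asarray(labels)
    if dist.shape != (labels.size, labels.size):
        raise ValueError("dist must be square with one row per label")
    result = {}
    for cluster in np.unique(labels):
        idx = np.flatnonzero(labels == cluster)
        if idx.size < 2:
            result[int(cluster)] = 0.0  # a singleton cluster has no pairs
            continue
        block = dist[np.ix_(idx, idx)]
        # Use the strict upper triangle so each pair is counted once.
        upper = block[np.triu_indices(idx.size, k=1)]
        result[int(cluster)] = float(upper.mean())
    return result
```

With the variables from the commented example, the call would be `mean_intra_cluster_distance(dist_matrix, labels)`.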