# Local HF Pipeline - End-to-End Test

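This notebook validates that the **fully local** KohakuRAG pipeline works
without calling any hosted API. Model weights are fetched from the Hugging Face
Hub on first use, so cache them beforehand if the run must be strictly offline.
It tests: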

1. Local embeddings (`LocalHFEmbeddingModel` via sentence-transformers)
2. Local LLM chat (`HuggingFaceLocalChatModel` via transformers)
3. Full RAG pipeline: index documents, retrieve, and answer

**Prerequisites:**
- Kernel: `kohaku-gb10` (or your project venv)
- Dependencies installed: `pip install -r local_requirements.txt`
- Vendored packages installed: `pip install -e vendor/KohakuVault && pip install -e vendor/KohakuRAG`

## Step 1 - Verify imports

In [None]:
import torch
import transformers
import sentence_transformers

print(f"torch:                {torch.__version__}")
print(f"transformers:         {transformers.__version__}")
print(f"sentence-transformers: {sentence_transformers.__version__}")
print(f"CUDA available:       {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU:                  {torch.cuda.get_device_name(0)}")
print()

import kohakurag
import kohakuvault
print(f"kohakurag:  {kohakurag.__file__}")
print(f"kohakuvault: {kohakuvault.__file__}")
print("\nAll imports OK")

## Step 2 - Test local embeddings

In [None]:
from kohakurag.embeddings import LocalHFEmbeddingModel

# Use a small, fast model for testing
EMBED_MODEL_NAME = "BAAI/bge-base-en-v1.5"
embedder = LocalHFEmbeddingModel(model_name=EMBED_MODEL_NAME)
print(f"Embedding model loaded: {EMBED_MODEL_NAME}")
print(f"Embedding dimension:    {embedder.dimension}")

In [None]:
import numpy as np

test_texts = [
    "Solar panels convert sunlight into electricity.",
    "Photovoltaic cells generate power from solar radiation.",
    "The capital of France is Paris.",
]

vecs = await embedder.embed(test_texts)
print(f"Embedding shape: {vecs.shape}")
print(f"Dtype:           {vecs.dtype}")

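# Cosine similarity. Dividing by the norms keeps this valid even if the
# embedder does not return unit-length vectors.
def cosine(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

sim_01 = cosine(vecs[0], vecs[1])
sim_02 = cosine(vecs[0], vecs[2])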
print(f"\nSimilarity (solar vs photovoltaic): {sim_01:.4f}  (should be high)")
print(f"Similarity (solar vs Paris):         {sim_02:.4f}  (should be low)")
assert sim_01 > sim_02, "Semantic similarity check failed!"
print("\nEmbedding sanity check PASSED")
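
The check above relies on cosine similarity. As a minimal, dependency-free sketch of what that measure computes (the vectors are made up for illustration and are not real embeddings; `cosine_similarity` is a hypothetical helper, not part of KohakuRAG):

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Toy vectors (made up for illustration)
u = [3.0, 4.0]
v = [6.0, 8.0]   # same direction as u, twice the length
w = [-4.0, 3.0]  # perpendicular to u

print(cosine_similarity(u, v))  # 1.0: same direction, length is ignored
print(cosine_similarity(u, w))  # 0.0: orthogonal
```

A plain dot product equals this value only when both vectors have unit length, which is why the normalization assumption matters.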

## Step 3 - Test local LLM chat

This loads a local HF model for generation. The default is `Qwen/Qwen2.5-7B-Instruct`.

**Note:** If this is too large for your GPU, change `LLM_MODEL_ID` to a smaller model
like `Qwen/Qwen2.5-1.5B-Instruct` or `TinyLlama/TinyLlama-1.1B-Chat-v1.0`.
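
To judge whether a model is likely to fit, a rough back-of-the-envelope estimate of the weight memory can help. The sketch below assumes the nominal parameter counts implied by the model names (7B, 1.5B, 1.1B) and counts weights only; activations, the KV cache, and framework overhead come on top. `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(n_params: float, bytes_per_param: int) -> float:
    """Approximate memory for the weights alone, in decimal gigabytes."""
    return n_params * bytes_per_param / 1e9


# Bytes used per parameter for common dtypes
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp16": 2}

# Nominal sizes taken from the model names (an assumption, not exact counts)
for name, n_params in [("7B", 7e9), ("1.5B", 1.5e9), ("1.1B", 1.1e9)]:
    gb = weight_memory_gb(n_params, BYTES_PER_PARAM["bf16"])
    print(f"{name} parameters in bf16: ~{gb:.1f} GB of weights")
```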

In [None]:
# Configure the LLM model - adjust if needed for your hardware
LLM_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"  # change to smaller model if OOM
LLM_DTYPE = "bf16"  # "bf16", "fp16", or "auto"

print(f"Will load: {LLM_MODEL_ID} ({LLM_DTYPE})")

In [None]:
from kohakurag.llm import HuggingFaceLocalChatModel

chat = HuggingFaceLocalChatModel(
    model=LLM_MODEL_ID,
    dtype=LLM_DTYPE,
    max_new_tokens=256,
    temperature=0.0,  # greedy for reproducibility
)
print(f"LLM loaded: {LLM_MODEL_ID}")

In [None]:
response = await chat.complete(
    "What is 2 + 2? Answer with just the number.",
    system_prompt="You are a helpful assistant. Be concise.",
)
print(f"LLM response: {response!r}")
assert "4" in response, f"Expected '4' in response, got: {response}"
print("LLM sanity check PASSED")

## Step 4 - Full RAG pipeline with train_QA.csv

This step loads real WattBot questions from `data/train_QA.csv`, creates
a small document corpus from our sample data, indexes it with proper
hierarchy (document -> paragraph -> sentence), and runs retrieval + QA.
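
The hierarchy can be pictured with toy data before running the real indexing cell. This is an illustrative sketch, not KohakuRAG code: `mean_pool` and `l2_normalize` are hypothetical helpers, the vectors are made up, and whether the library's `average_embeddings` renormalizes its result is not shown in this notebook.

```python
import math


def mean_pool(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def l2_normalize(vec):
    """Rescale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


# Toy 2-D "sentence embeddings" (made up for illustration)
sentence_vecs = [[1.0, 0.0], [0.0, 1.0]]

doc_id = "demo_doc"                # hypothetical document id
paragraph_id = f"{doc_id}:p0"      # single paragraph under the document
sentence_ids = [f"{paragraph_id}:s{i}" for i in range(len(sentence_vecs))]

# A parent embedding derived from its children
paragraph_vec = l2_normalize(mean_pool(sentence_vecs))

print(doc_id, "->", paragraph_id, "->", sentence_ids)
print(paragraph_vec)
```

The ID scheme (`doc`, `doc:p0`, `doc:p0:s0`) mirrors the one used in the indexing cell, so a child's parent can be read directly from its ID.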

In [None]:
import csv
from pathlib import Path

# Load train_QA.csv
qa_path = Path("../data/train_QA.csv")
if not qa_path.exists():
    qa_path = Path("data/train_QA.csv")  # fallback if running from repo root

with qa_path.open(newline="", encoding="utf-8-sig") as f:
    reader = csv.DictReader(f)
    qa_rows = list(reader)

print(f"Loaded {len(qa_rows)} questions from {qa_path.name}")
print(f"Columns: {list(qa_rows[0].keys())}")
print(f"\nFirst 5 questions:")
for row in qa_rows[:5]:
    print(f"  [{row['id']}] {row['question'][:90]}...")
    print(f"         expected: {row['answer_value']} ({row.get('answer_unit', '')})")
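
The loader opens the CSV with `encoding="utf-8-sig"`. A small sketch of why that choice matters, using made-up rows rather than the real `train_QA.csv`: if a file starts with a byte-order mark, plain `utf-8` leaves it attached to the first column name, so `row['id']` would raise a `KeyError`. `header_names` is a hypothetical helper.

```python
import csv
import io

# Made-up CSV content that starts with a UTF-8 byte-order mark (BOM)
raw = "\ufeffid,question\nq001,What is PUE?\n".encode("utf-8")


def header_names(data: bytes, encoding: str) -> list:
    """Decode the bytes with the given encoding and return the CSV header."""
    text = io.StringIO(data.decode(encoding), newline="")
    return list(csv.DictReader(text).fieldnames)


print(header_names(raw, "utf-8"))      # first name still carries the BOM
print(header_names(raw, "utf-8-sig"))  # BOM stripped: ['id', 'question']
```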

In [None]:
from kohakurag.types import NodeKind, StoredNode
from kohakurag.embeddings import average_embeddings
from kohakurag.pipeline import RAGPipeline
from kohakurag.datastore import InMemoryNodeStore

# Sample documents (sustainable AI topics that overlap with train_QA questions)
documents = [
    {
        "id": "patterson2021",
        "title": "Carbon Emissions and Large Neural Networks",
        "sentences": [
            "Training GPT-3 (175B parameters) was estimated to emit approximately 552 tonnes of CO2.",
            "Training GShard-600B used 24 MWh and produced 4.3 net tCO2e.",
            "Smaller models like Llama-2-7B require roughly 30x less compute.",
            "Techniques such as mixed-precision training and gradient checkpointing can further reduce energy consumption by 20-30%.",
        ],
    },
    {
        "id": "wu2021b",
        "title": "Sustainable AI and Data Center Efficiency",
        "sentences": [
            "Modern data centers consume approximately 1-2% of global electricity.",
            "Hyperscale data centers in 2020 achieved more than 40% higher efficiency compared to traditional data centers.",
            "Google reported a PUE (Power Usage Effectiveness) of 1.10 across its fleet in 2023.",
            "Liquid cooling systems can reduce energy usage by up to 40% compared to traditional air cooling.",
        ],
    },
    {
        "id": "li2025b",
        "title": "Water Consumption of AI Systems",
        "sentences": [
            "GPT-3 needs to drink a 500ml bottle of water for roughly 10 to 50 medium-length responses.",
            "The estimated total operational water consumption for training GPT-3 in Microsoft's U.S. data centers was 5.439 million liters.",
            "Microsoft committed to being carbon negative by 2030.",
            "Azure data centers in Sweden run on 100% renewable energy.",
        ],
    },
    {
        "id": "strubell2019",
        "title": "Energy and Policy Considerations for Deep Learning",
        "sentences": [
            "Authors should report training time and computational resources required for reproducibility.",
            "Tracking the runtime of a training job is an important step for estimating compute cost in GPU-based or cloud environments.",
            "The financial cost of training a large transformer model can exceed $1 million.",
        ],
    },
]

# Build hierarchical nodes: document -> paragraph -> sentence
# The pipeline expects parent nodes to exist when walking the hierarchy
nodes = []

for doc in documents:
    doc_id = doc["id"]
    
    # Embed all sentences at once
    sent_vecs = await embedder.embed(doc["sentences"])
    
    # Create sentence nodes
    sent_node_ids = []
    for s_idx, (sent, vec) in enumerate(zip(doc["sentences"], sent_vecs)):
        sent_id = f"{doc_id}:p0:s{s_idx}"
        sent_node_ids.append(sent_id)
        nodes.append(StoredNode(
            node_id=sent_id,
            parent_id=f"{doc_id}:p0",
            kind=NodeKind.SENTENCE,
            title=doc["title"],
            text=sent,
            metadata={"document_id": doc_id},
            embedding=vec,
            child_ids=[],
        ))
    
    # Create paragraph node (parent of sentences) with averaged embedding
    para_vec = average_embeddings([v for v in sent_vecs])
    nodes.append(StoredNode(
        node_id=f"{doc_id}:p0",
        parent_id=doc_id,
        kind=NodeKind.PARAGRAPH,
        title=doc["title"],
        text=" ".join(doc["sentences"]),
        metadata={"document_id": doc_id},
        embedding=para_vec,
        child_ids=sent_node_ids,
    ))
    
    # Create document node (root) with averaged embedding
    nodes.append(StoredNode(
        node_id=doc_id,
        parent_id=None,
        kind=NodeKind.DOCUMENT,
        title=doc["title"],
        text=doc["title"],
        metadata={"document_id": doc_id},
        embedding=para_vec,  # same as paragraph for single-paragraph docs
        child_ids=[f"{doc_id}:p0"],
    ))

# Create in-memory store and index
store = InMemoryNodeStore()
await store.upsert_nodes(nodes)
print(f"Indexed {len(nodes)} nodes ({len(documents)} docs) into in-memory store")
print(f"  Hierarchy: document -> paragraph -> sentences")

In [None]:
# Assemble pipeline with local components
pipeline = RAGPipeline(
    store=store,
    embedder=embedder,
    chat_model=chat,
    top_k=3,
)
print("Pipeline assembled (local embedder + local LLM + in-memory store)")

In [None]:
# Test retrieval with a real WattBot question
question = qa_rows[0]["question"]  # First question from train_QA.csv

result = await pipeline.retrieve(question, top_k=3)
print(f"Question: {question}")
print(f"Retrieved {len(result.matches)} matches:\n")
for i, match in enumerate(result.matches):
    print(f"  [{i+1}] score={match.score:.4f}  node={match.node.node_id}")
    print(f"      {match.node.text[:120]}...\n")

In [None]:
# Test full QA (retrieve + generate)
answer = await pipeline.answer(question)

print(f"Question: {answer['question']}\n")
print(f"Expected: {qa_rows[0]['answer_value']}")
print(f"\nResponse:\n{answer['response']}")

## Step 5 - Structured QA (JSON output)

This tests the `run_qa` method with the same prompt templates used in production.

In [None]:
import json

system_prompt = (
    "You must answer strictly based on the provided context snippets. "
    "Do NOT use external knowledge. If the context does not support an answer, "
    "output 'is_blank' for answer_value. Respond in JSON with keys: "
    "explanation, answer, answer_value, ref_id."
)

user_template = """Question: {question}

Context:
{context}

Additional info: {additional_info_json}

Return STRICT JSON:
- explanation: 1-2 sentences
- answer: short answer
- answer_value: numeric/categorical value or "is_blank"
- ref_id: list of document ids used

JSON Answer:"""

# Use a question from train_QA that should match our sample docs
# q009: "What were the net CO2e emissions from training the GShard-600B model?"
gshard_row = next(r for r in qa_rows if "GShard" in r["question"])

structured_result = await pipeline.run_qa(
    question=gshard_row["question"],
    system_prompt=system_prompt,
    user_template=user_template,
    additional_info={"answer_unit": gshard_row.get("answer_unit", "")},
    top_k=3,
)

print(f"Question:     {gshard_row['question']}")
print(f"Expected:     {gshard_row['answer_value']} ({gshard_row.get('answer_unit', '')})")
print(f"Answer:       {structured_result.answer.answer}")
print(f"Answer value: {structured_result.answer.answer_value}")
print(f"Ref IDs:      {structured_result.answer.ref_id}")
print(f"Explanation:  {structured_result.answer.explanation}")
print(f"\nRaw LLM output:\n{structured_result.raw_response[:500]}")
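
Local models do not always return bare JSON, even when asked for strict JSON, so it can help to inspect `raw_response` with a tolerant parser. The sketch below is one possible approach using only the standard library; `extract_json_object` is a hypothetical helper and is not claimed to be how `run_qa` parses its output.

```python
import json


def extract_json_object(raw: str):
    """Return the first decodable JSON object found in raw, or None."""
    decoder = json.JSONDecoder()
    start = raw.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(raw, start)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            return obj
        # This brace did not start a valid object; try the next one
        start = raw.find("{", start + 1)
    return None


# Example model output with chatter around the JSON (made up for illustration)
sample = (
    'Here is the answer: {"answer_value": "4.3", "ref_id": ["patterson2021"]} '
    "Let me know if you need more detail."
)
parsed = extract_json_object(sample)
print(parsed["answer_value"], parsed["ref_id"])
```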

## Step 6 - Offline validation

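Confirm that the pipeline does not depend on any hosted API by removing the API
keys from the environment and re-running a query. This shows that no key is
required; it does not by itself prove that no network traffic occurred.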

In [None]:
import os

# Clear any API keys to prove we're fully local
for key in ["OPENROUTER_API_KEY", "OPENAI_API_KEY", "JINA_API_KEY"]:
    os.environ.pop(key, None)

# Re-run a query - should work without any API keys
offline_answer = await pipeline.answer(
    "What percentage of global electricity do data centers use?"
)
print(f"Offline response:\n{offline_answer['response']}")
print("\nOFFLINE VALIDATION PASSED - no API keys needed!")
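
For a stricter check than removing API keys, one option is to make outbound connections fail loudly while a query runs. The sketch below is an assumption-laden illustration: `block_network` is a hypothetical helper that patches `socket.socket.connect`, so it only catches connections opened through that method and not, for example, `connect_ex` or networking done inside native extensions.

```python
import socket
from contextlib import contextmanager


@contextmanager
def block_network():
    """Make socket.socket.connect raise while the with-block is active."""
    original_connect = socket.socket.connect

    def guarded_connect(self, address):
        raise RuntimeError(f"Network access attempted: {address!r}")

    socket.socket.connect = guarded_connect
    try:
        yield
    finally:
        # Always restore the real method, even if the block raised
        socket.socket.connect = original_connect


# Demonstration: the connection attempt is intercepted before any traffic
with block_network():
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.connect(("127.0.0.1", 9))
    except RuntimeError as exc:
        print(f"Blocked as expected: {exc}")
    finally:
        probe.close()
```

In the notebook, the `pipeline.answer(...)` call could be placed inside `with block_network():`. Setting the environment variables `HF_HUB_OFFLINE=1` and `TRANSFORMERS_OFFLINE=1` before the Hugging Face libraries are imported is a complementary way to keep them from contacting the Hub.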

## Summary

If every cell in this notebook (including Steps 7 and 8 below) runs successfully, your local HF pipeline is working:

| Component | Provider | Model |
|-----------|----------|-------|
| Embeddings | `LocalHFEmbeddingModel` | `BAAI/bge-base-en-v1.5` |
| LLM | `HuggingFaceLocalChatModel` | Configured above |
| Vector store | `InMemoryNodeStore` | (in-memory, no DB needed) |

**What was tested:**
- Steps 1-3: Individual component verification (imports, embeddings, LLM)
- Step 4: Full RAG pipeline with train_QA.csv questions
- Step 5: Structured JSON QA (production format)
- Step 6: Offline validation (no API keys)
- Step 7: Batch test with multiple WattBot questions
- Step 8a: Serial ensemble voting (single GPU)
- Step 8b: Parallel ensemble voting (multi-GPU with `CUDA_VISIBLE_DEVICES`)
- Step 8c: Batch ensemble across multiple questions

**Ensemble modes:**
- `ref_mode="union"` — vote on answer, union ref_ids from agreeing runs
- `ref_mode="intersection"` — vote on answer, intersect ref_ids
- `ref_mode="independent"` — vote answer and ref_id separately
- `ignore_blank=True` — filter out `is_blank` before voting (default)
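
The three `ref_mode` settings can be compared on toy data. This is a simplified sketch with made-up run outputs; it ignores `ignore_blank` and tie-breaking, and the variable names are illustrative rather than taken from the production code.

```python
from collections import Counter

# Made-up outputs from four ensemble runs
runs = [
    {"answer_value": "4.3", "ref_id": ["patterson2021", "wu2021b"]},
    {"answer_value": "4.3", "ref_id": ["patterson2021", "strubell2019"]},
    {"answer_value": "4.3", "ref_id": ["patterson2021"]},
    {"answer_value": "24", "ref_id": ["patterson2021", "wu2021b"]},
]

# Majority vote on the answer value
winner = Counter(r["answer_value"] for r in runs).most_common(1)[0][0]

# Reference sets from the runs that agree with the winning answer
agreeing = [set(r["ref_id"]) for r in runs if r["answer_value"] == winner]

# union: every reference cited by any agreeing run
union_refs = sorted(set().union(*agreeing))

# intersection: only references cited by all agreeing runs
intersection_refs = sorted(set.intersection(*agreeing))

# independent: vote on the reference lists across ALL runs, ignoring the answer
ref_votes = Counter(tuple(sorted(r["ref_id"])) for r in runs)
independent_refs = list(ref_votes.most_common(1)[0][0])

print("winner:      ", winner)
print("union:       ", union_refs)
print("intersection:", intersection_refs)
print("independent: ", independent_refs)
```

Here `union` is the most inclusive, `intersection` the most conservative, and `independent` can be swayed by a run that lost the answer vote.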

To use with the full production pipeline (KVaultNodeStore + pre-indexed docs),
set `llm_provider = "hf_local"` and `embedding_model = "hf_local"` in your config.

## Step 7 - Batch test with multiple WattBot questions

Run several questions from `train_QA.csv` through the pipeline, including
ones that should match our sample docs and ones that won't (testing "is_blank").

In [None]:
# Pick questions that test different scenarios
sample_questions = [
    # Should match: hyperscale data centers (wu2021b)
    next(r for r in qa_rows if "Hyperscale" in r["question"]),
    # Should match: water consumption (li2025b)
    next(r for r in qa_rows if "water consumption" in r["question"].lower() and "training GPT-3" in r["question"]),
    # Should match: tracking runtime (strubell2019)
    next(r for r in qa_rows if "runtime" in r["question"].lower() and "training job" in r["question"].lower()),
    # Should NOT match: elephant (tests is_blank)
    next(r for r in qa_rows if "elephant" in r["question"].lower()),
]

print(f"Running {len(sample_questions)} WattBot questions through local pipeline...\n")
print("=" * 70)

for row in sample_questions:
    qid = row["id"]
    question = row["question"]
    expected = row["answer_value"]
    unit = row.get("answer_unit", "")

    result = await pipeline.run_qa(
        question=question,
        system_prompt=system_prompt,
        user_template=user_template,
        additional_info={"answer_unit": unit},
        top_k=3,
    )

    print(f"\n[{qid}] {question[:85]}...")
    print(f"  Expected:  {expected} ({unit})")
    print(f"  Got:       {result.answer.answer_value}")
    print(f"  Answer:    {result.answer.answer}")
    print(f"  Ref IDs:   {result.answer.ref_id}")
    print("-" * 70)

print("\nBatch WattBot test completed (pipeline ran without errors)")

## Step 8 - Ensemble voting with local HF models

KohakuRAG's competition-winning strategy uses **ensemble voting**: run the
same question through the pipeline N times (with temperature > 0 so outputs
vary), then pick the answer that appears most often via majority vote.

This step demonstrates two modes:
- **Serial** (single GPU) — runs N passes sequentially, safe for any setup
- **Parallel** (multi-GPU) — spawns workers across GPUs via `CUDA_VISIBLE_DEVICES`

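The voting helpers below mirror the logic of the production `wattbot_aggregate.py`
script. That script offers the strategies `union`, `intersection`, `independent`,
`ref_priority`, and `answer_priority`; the helpers in this notebook implement
only the first three.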

In [None]:
import time
from collections import Counter
from concurrent.futures import ProcessPoolExecutor


# ---------------------------------------------------------------------------
# Voting / aggregation helpers (same logic as wattbot_aggregate.py)
# ---------------------------------------------------------------------------

def majority_vote(values: list[str], ignore_blank: bool = True) -> str:
    """Pick the most-common value; first occurrence breaks ties.
    
    Args:
        values: candidate answers from N runs
        ignore_blank: drop 'is_blank' entries before voting (if non-blanks exist)
    """
    if not values:
        return "is_blank"
    if ignore_blank:
        non_blank = [v for v in values if v != "is_blank"]
        if non_blank:
            values = non_blank
    counter = Counter(values)
    max_count = counter.most_common(1)[0][1]
    # First-occurrence tiebreak
    for v in values:
        if counter[v] == max_count:
            return v
    return values[0]


def aggregate_ensemble(
    all_run_answers: list[dict],
    ref_mode: str = "union",
    ignore_blank: bool = True,
) -> dict:
    """Aggregate N answers for a single question.
    
    Args:
        all_run_answers: list of dicts, each with keys
            answer_value, answer, ref_id, explanation
        ref_mode: "union" | "intersection" | "independent"
        ignore_blank: filter is_blank before voting
        
    Returns:
        dict with voted answer_value, answer, ref_ids, per-run details
    """
    answer_values = [r.get("answer_value", "is_blank") for r in all_run_answers]
    best_value = majority_vote(answer_values, ignore_blank)

    # Collect ref_ids from runs that agree with the winning answer
    matching = [r for r in all_run_answers if r.get("answer_value") == best_value]
    if not matching:
        matching = all_run_answers  # fallback

    if ref_mode == "union":
        all_refs = set()
        for r in matching:
            refs = r.get("ref_id", [])
            if isinstance(refs, list):
                all_refs.update(refs)
            elif isinstance(refs, str) and refs != "is_blank":
                all_refs.add(refs)
        voted_refs = sorted(all_refs) if all_refs else ["is_blank"]
    elif ref_mode == "intersection":
        ref_sets = []
        for r in matching:
            refs = r.get("ref_id", [])
            if isinstance(refs, list):
                ref_sets.append(set(refs))
        if ref_sets:
            voted_refs = sorted(ref_sets[0].intersection(*ref_sets[1:]))
        else:
            voted_refs = ["is_blank"]
        if not voted_refs:
            voted_refs = ["is_blank"]
    else:  # independent
        ref_strs = [str(r.get("ref_id", "is_blank")) for r in all_run_answers]
        voted_refs = [majority_vote(ref_strs, ignore_blank)]

    # Use explanation from the first matching run
    best_answer = matching[0].get("answer", best_value)
    best_explanation = matching[0].get("explanation", "")

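    return {
        "answer_value": best_value,
        "answer": best_answer,
        "ref_id": voted_refs,
        "explanation": best_explanation,
        "vote_counts": dict(Counter(answer_values)),
        "n_runs": len(all_run_answers),
        # Share of runs that returned the winning value. Using the overall
        # most common value here would report agreement on "is_blank" even
        # when ignore_blank excluded it from the vote.
        "agreement": answer_values.count(best_value) / len(all_run_answers),
    }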


print("Ensemble voting helpers loaded")

### 8a) Serial ensemble (single GPU)

Runs the pipeline N times with `temperature > 0` so each pass produces
slightly different outputs. Then aggregates with majority voting.
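
Why a temperature above zero produces varied outputs can be seen from the usual temperature-scaled softmax, in which logits are divided by the temperature before normalization. The sketch below uses made-up logits for three candidate tokens; `softmax_with_temperature` is a hypothetical helper, and how `HuggingFaceLocalChatModel` applies the setting internally is not shown in this notebook.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Turn logits into sampling probabilities after temperature scaling."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0; greedy decoding uses argmax")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]  # made-up scores for three candidate tokens
for t in (0.2, 0.6, 1.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: {[round(p, 3) for p in probs]}")
```

Lower temperatures concentrate probability on the top token, so runs rarely differ; higher temperatures spread it out, which gives the vote something to aggregate.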

In [None]:
# ---------------------------------------------------------------------------
# Serial ensemble: N runs on the same GPU, one after another
# ---------------------------------------------------------------------------

ENSEMBLE_SIZE = 5         # Number of runs per question
ENSEMBLE_TEMPERATURE = 0.6  # Needs > 0 for diversity across runs

# Build a pipeline with temperature > 0 for the ensemble
# (reuses the same embedder + store from Step 4)
chat_ensemble = HuggingFaceLocalChatModel(
    model=LLM_MODEL_ID,
    dtype=LLM_DTYPE,
    max_new_tokens=256,
    temperature=ENSEMBLE_TEMPERATURE,
)

pipeline_ensemble = RAGPipeline(
    store=store,
    embedder=embedder,
    chat_model=chat_ensemble,
    top_k=3,
)

# Pick a question from train_QA
test_row = next(r for r in qa_rows if "GShard" in r["question"])
test_question = test_row["question"]
test_expected = test_row["answer_value"]
test_unit = test_row.get("answer_unit", "")

print(f"Question:  {test_question}")
print(f"Expected:  {test_expected} ({test_unit})")
print(f"Ensemble:  {ENSEMBLE_SIZE} runs @ temperature={ENSEMBLE_TEMPERATURE}")
print(f"\nRunning serial ensemble...")
print("-" * 60)

run_answers = []
t0 = time.time()

for i in range(ENSEMBLE_SIZE):
    result = await pipeline_ensemble.run_qa(
        question=test_question,
        system_prompt=system_prompt,
        user_template=user_template,
        additional_info={"answer_unit": test_unit},
        top_k=3,
    )
    answer_dict = {
        "answer_value": result.answer.answer_value,
        "answer": result.answer.answer,
        "ref_id": result.answer.ref_id,
        "explanation": result.answer.explanation,
    }
    run_answers.append(answer_dict)
    print(f"  Run {i+1}/{ENSEMBLE_SIZE}: answer_value={result.answer.answer_value!r}")

elapsed = time.time() - t0
print(f"\n{ENSEMBLE_SIZE} runs completed in {elapsed:.1f}s ({elapsed/ENSEMBLE_SIZE:.1f}s/run)")

# Aggregate with majority voting
voted = aggregate_ensemble(run_answers, ref_mode="union", ignore_blank=True)

print(f"\n{'=' * 60}")
print(f"ENSEMBLE RESULT (majority vote, ref_mode=union)")
print(f"{'=' * 60}")
print(f"  Voted answer:  {voted['answer_value']}")
print(f"  Expected:      {test_expected}")
print(f"  Agreement:     {voted['agreement']:.0%} ({voted['n_runs']} runs)")
print(f"  Vote counts:   {voted['vote_counts']}")
print(f"  Ref IDs:       {voted['ref_id']}")
print(f"  Explanation:   {voted['explanation'][:200]}")

### 8b) Parallel ensemble (multi-GPU)

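When you have multiple GPUs (e.g., PowerEdge with 2x A100), each worker
gets its own GPU via `CUDA_VISIBLE_DEVICES`. Each process loads its own
model copy and runs independently, so workers do not compete for VRAM as
long as no two of them are pinned to the same GPU at once.

**Caveat:** `CUDA_VISIBLE_DEVICES` only takes effect if it is set before CUDA
is initialized in that process. With the `fork` start method, workers inherit
the kernel's already-initialized CUDA state, and PyTorch does not support CUDA
in forked subprocesses. If the workers fail for that reason, start them with
the `spawn` method and move the worker function into an importable module,
because `spawn` cannot import functions defined in a notebook cell.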

**How it works:**
1. Spawns N worker **processes** (one per GPU, round-robin)
2. Each process sets `CUDA_VISIBLE_DEVICES` to its assigned GPU
3. Loads the model fresh (no shared memory between processes)
4. Runs the pipeline and returns the answer dict
5. Main process collects all results and runs majority vote
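
The round-robin assignment in step 1 can be sketched in a few lines. `assign_round_robin` is a hypothetical helper that only plans which run goes to which GPU; it does not touch CUDA.

```python
from collections import defaultdict


def assign_round_robin(n_runs: int, n_gpus: int) -> dict:
    """Map each GPU index to the list of run ids it should execute."""
    if n_gpus < 1:
        raise ValueError("need at least one GPU")
    plan = defaultdict(list)
    for run_id in range(n_runs):
        plan[run_id % n_gpus].append(run_id)
    return dict(plan)


print(assign_round_robin(6, 2))  # {0: [0, 2, 4], 1: [1, 3, 5]}
```

With 6 runs on 2 GPUs, each GPU receives 3 runs.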

> **Tip:** On a 2-GPU server, `N_GPUS` is detected as 2. Set
> `PARALLEL_ENSEMBLE_SIZE=6` to run 3 passes per GPU.

In [None]:
# ---------------------------------------------------------------------------
# Parallel ensemble worker function (runs in a subprocess)
# ---------------------------------------------------------------------------

def _ensemble_worker(args: dict) -> dict:
    """Run a single ensemble pass in a subprocess with a specific GPU.
    
    This function is called by ProcessPoolExecutor. Each invocation:
    1. Pins to a specific GPU via CUDA_VISIBLE_DEVICES
    2. Loads embedder + LLM from scratch (subprocess has its own memory)
    3. Rebuilds the in-memory store + pipeline
    4. Runs run_qa and returns the answer dict
    
    Args:
        args: dict with keys:
            gpu_id, run_id, question, system_prompt, user_template,
            additional_info, llm_model_id, llm_dtype, temperature,
            embed_model_name, documents, top_k
    
    Returns:
        dict with answer_value, answer, ref_id, explanation, run_id, gpu_id
    """
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args["gpu_id"])
    
    import asyncio
    import numpy as np
    
    from kohakurag.embeddings import LocalHFEmbeddingModel, average_embeddings
    from kohakurag.llm import HuggingFaceLocalChatModel
    from kohakurag.pipeline import RAGPipeline
    from kohakurag.datastore import InMemoryNodeStore
    from kohakurag.types import NodeKind, StoredNode
    
    async def _run():
        # Load models on this GPU
        embedder = LocalHFEmbeddingModel(model_name=args["embed_model_name"])
        chat = HuggingFaceLocalChatModel(
            model=args["llm_model_id"],
            dtype=args["llm_dtype"],
            max_new_tokens=256,
            temperature=args["temperature"],
        )
        
        # Rebuild the store (each subprocess needs its own)
        nodes = []
        for doc in args["documents"]:
            doc_id = doc["id"]
            sent_vecs = await embedder.embed(doc["sentences"])
            
            sent_node_ids = []
            for s_idx, (sent, vec) in enumerate(zip(doc["sentences"], sent_vecs)):
                sent_id = f"{doc_id}:p0:s{s_idx}"
                sent_node_ids.append(sent_id)
                nodes.append(StoredNode(
                    node_id=sent_id,
                    parent_id=f"{doc_id}:p0",
                    kind=NodeKind.SENTENCE,
                    title=doc["title"],
                    text=sent,
                    metadata={"document_id": doc_id},
                    embedding=vec,
                    child_ids=[],
                ))
            
            para_vec = average_embeddings([v for v in sent_vecs])
            nodes.append(StoredNode(
                node_id=f"{doc_id}:p0",
                parent_id=doc_id,
                kind=NodeKind.PARAGRAPH,
                title=doc["title"],
                text=" ".join(doc["sentences"]),
                metadata={"document_id": doc_id},
                embedding=para_vec,
                child_ids=sent_node_ids,
            ))
            nodes.append(StoredNode(
                node_id=doc_id,
                parent_id=None,
                kind=NodeKind.DOCUMENT,
                title=doc["title"],
                text=doc["title"],
                metadata={"document_id": doc_id},
                embedding=para_vec,
                child_ids=[f"{doc_id}:p0"],
            ))
        
        store = InMemoryNodeStore()
        await store.upsert_nodes(nodes)
        
        pipeline = RAGPipeline(
            store=store,
            embedder=embedder,
            chat_model=chat,
            top_k=args["top_k"],
        )
        
        result = await pipeline.run_qa(
            question=args["question"],
            system_prompt=args["system_prompt"],
            user_template=args["user_template"],
            additional_info=args["additional_info"],
            top_k=args["top_k"],
        )
        
        return {
            "answer_value": result.answer.answer_value,
            "answer": result.answer.answer,
            "ref_id": result.answer.ref_id,
            "explanation": result.answer.explanation,
            "run_id": args["run_id"],
            "gpu_id": args["gpu_id"],
        }
    
    return asyncio.run(_run())


print("Parallel ensemble worker defined")

In [None]:
# ---------------------------------------------------------------------------
# Run the parallel ensemble
# ---------------------------------------------------------------------------

# Configuration — adjust for your hardware
N_GPUS = torch.cuda.device_count()  # auto-detect
PARALLEL_ENSEMBLE_SIZE = 4           # total runs (distributed across GPUs)

print(f"Detected {N_GPUS} GPU(s)")

if N_GPUS < 2:
    print(
        "\nSkipping parallel ensemble (needs 2+ GPUs).\n"
        "The serial ensemble in Step 8a works on single-GPU setups.\n"
        "On your PowerEdge with 2 GPUs, this will run automatically."
    )
else:
    print(f"Running {PARALLEL_ENSEMBLE_SIZE} passes across {N_GPUS} GPUs...")
    
    # Build worker args — round-robin GPU assignment
    worker_args = []
    for run_id in range(PARALLEL_ENSEMBLE_SIZE):
        gpu_id = run_id % N_GPUS
        worker_args.append({
            "gpu_id": gpu_id,
            "run_id": run_id,
            "question": test_question,
            "system_prompt": system_prompt,
            "user_template": user_template,
            "additional_info": {"answer_unit": test_unit},
            "llm_model_id": LLM_MODEL_ID,
            "llm_dtype": LLM_DTYPE,
            "temperature": ENSEMBLE_TEMPERATURE,
            "embed_model_name": "BAAI/bge-base-en-v1.5",
            "documents": documents,  # serializable list of dicts
            "top_k": 3,
        })
    
    t0 = time.time()
    
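    # Use ProcessPoolExecutor with one worker process per GPU.
    # max_workers=N_GPUS caps concurrency at N_GPUS runs. The round-robin
    # gpu_id assignment balances runs across GPUs but does not by itself
    # guarantee that two concurrent runs never land on the same GPU.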
    with ProcessPoolExecutor(max_workers=N_GPUS) as executor:
        parallel_results = list(executor.map(_ensemble_worker, worker_args))
    
    elapsed = time.time() - t0
    
    print(f"\n{PARALLEL_ENSEMBLE_SIZE} runs completed in {elapsed:.1f}s "
          f"({elapsed/PARALLEL_ENSEMBLE_SIZE:.1f}s/run effective, "
          f"{N_GPUS} GPUs)")
    
    for r in parallel_results:
        print(f"  Run {r['run_id']} (GPU {r['gpu_id']}): {r['answer_value']!r}")
    
    # Aggregate
    voted_parallel = aggregate_ensemble(parallel_results, ref_mode="union")
    
    print(f"\n{'=' * 60}")
    print(f"PARALLEL ENSEMBLE RESULT ({N_GPUS} GPUs)")
    print(f"{'=' * 60}")
    print(f"  Voted answer:  {voted_parallel['answer_value']}")
    print(f"  Expected:      {test_expected}")
    print(f"  Agreement:     {voted_parallel['agreement']:.0%}")
    print(f"  Vote counts:   {voted_parallel['vote_counts']}")
    print(f"  Ref IDs:       {voted_parallel['ref_id']}")
    print(f"  Speedup:       up to ~{N_GPUS}x vs serial (ideal, not measured)")

### 8c) Batch ensemble over multiple questions

Run the serial ensemble across several train_QA questions and compare
voted answers to expected values.

In [None]:
# ---------------------------------------------------------------------------
# Batch ensemble: run N passes for each of several questions
# ---------------------------------------------------------------------------

BATCH_ENSEMBLE_SIZE = 3  # fewer runs per question to keep this fast

batch_questions = [
    next(r for r in qa_rows if "GShard" in r["question"]),
    next(r for r in qa_rows if "Hyperscale" in r["question"]),
    next(r for r in qa_rows if "water consumption" in r["question"].lower() and "training GPT-3" in r["question"]),
    next(r for r in qa_rows if "elephant" in r["question"].lower()),
]

print(f"Batch ensemble: {len(batch_questions)} questions x {BATCH_ENSEMBLE_SIZE} runs each\n")
print(f"{'ID':<8} {'Expected':<20} {'Voted':<20} {'Agree':>6}  Question")
print("-" * 100)

for row in batch_questions:
    qid = row["id"]
    question = row["question"]
    expected = row["answer_value"]
    unit = row.get("answer_unit", "")
    
    run_answers = []
    for _ in range(BATCH_ENSEMBLE_SIZE):
        result = await pipeline_ensemble.run_qa(
            question=question,
            system_prompt=system_prompt,
            user_template=user_template,
            additional_info={"answer_unit": unit},
            top_k=3,
        )
        run_answers.append({
            "answer_value": result.answer.answer_value,
            "answer": result.answer.answer,
            "ref_id": result.answer.ref_id,
            "explanation": result.answer.explanation,
        })
    
    voted = aggregate_ensemble(run_answers, ref_mode="union")
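    # Compare as strings: the CSV value is text, the model value may not be
    voted_value = str(voted["answer_value"])
    match = "Y" if voted_value == expected else " "
    
    print(
        f"{qid:<8} {expected:<20} {voted_value:<20} "
        f"{voted['agreement']:>5.0%} {match} {question[:50]}..."
    )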

print(f"\nBatch ensemble complete")