# RAG Benchmark & Optimization

## Preparation - Loading dataset

### Import

In [None]:
import os
import gc
import numpy as np
import pandas as pd
import time, pprint, json, re

from sklearn.decomposition import PCA
# import umap
# import faiss
# import hnswlib
from sklearn.neighbors import NearestNeighbors
import plotly.graph_objects as go

In [None]:
# Mount Google Drive at /content/drive; the data paths used in later cells
# all live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### Load Embedding

In [None]:
import glob

In [None]:
def load_chunked_embeddings(folder):
    """Load every embedding chunk pickle in ``folder`` into one DataFrame.

    Files matching ``gte_Qwen2_7B_instruct_*.pkl`` are read in sorted
    (lexicographic) path order and concatenated with a fresh 0..N-1 index.

    Args:
        folder: Directory that holds the chunk pickle files.

    Returns:
        The concatenated DataFrame, or an empty DataFrame with the columns
        ``text`` and ``embedding`` when no chunk file matches (a notice is
        printed in that case).
    """
    chunk_pattern = os.path.join(folder, "gte_Qwen2_7B_instruct_*.pkl")
    chunk_paths = sorted(glob.glob(chunk_pattern))

    # Nothing to load: hand back an empty frame with the expected schema.
    if not chunk_paths:
        print("No chunk files found in", folder)
        return pd.DataFrame(columns=["text", "embedding"])

    frames = [pd.read_pickle(path) for path in chunk_paths]
    return pd.concat(frames, ignore_index=True)

In [None]:
def load_questions(pkl_path):
    """Return the object unpickled from ``pkl_path`` via ``pd.read_pickle``.

    Args:
        pkl_path: Path to the pickled questions table.

    Returns:
        Whatever ``pd.read_pickle`` yields for that file (callers in this
        notebook treat it as a DataFrame with an ``embedding`` column).
    """
    questions = pd.read_pickle(pkl_path)
    return questions

In [None]:
def load_contexts(csv_path):
    """Read a UTF-8 CSV and return its ``context`` column as a Python list.

    Args:
        csv_path: Path to a CSV file that contains a ``context`` column.

    Returns:
        The values of the ``context`` column, in file order.
    """
    table = pd.read_csv(csv_path, encoding='utf-8')
    context_column = table['context']
    return context_column.tolist()

### Dimension Reduction

In [None]:
def reduce_embeddings(
    dataset_embs,
    question_embs,
    method_name,      # "pca", "truncate_first", "truncate_last", "truncate_four"
    target_dim,
    umap_kwargs=None,
    pca_kwargs=None
):
    """Reduce dataset and question embeddings to ``target_dim`` columns.

    Supported ``method_name`` values (case-insensitive):

    * ``"pca"``            - PCA fitted on the dataset, then applied to both.
    * ``"truncate_first"`` - keep the first ``target_dim`` columns.
    * ``"truncate_last"``  - keep the last ``target_dim`` columns.
    * ``"truncate_four"``  - keep four equal-width column runs that start at
      0%, 25%, 50% and 75% of the original width.

    Args:
        dataset_embs: 2-D array of dataset embeddings (rows x dims).
        question_embs: 2-D array of question embeddings, same column count.
        method_name: One of the method names listed above.
        target_dim: Number of columns to keep / PCA components.
        umap_kwargs: Accepted for interface compatibility; not used here.
        pca_kwargs: Extra keyword arguments forwarded to ``PCA``.

    Returns:
        ``(dataset_reduced, question_reduced)``, both as float32.

    Raises:
        ValueError: unknown method, ``target_dim`` larger than the original
            width for a truncation method, or ``target_dim`` not divisible
            by 4 for ``"truncate_four"``.
    """
    umap_kwargs = {} if umap_kwargs is None else umap_kwargs
    pca_kwargs = {} if pca_kwargs is None else pca_kwargs

    kind = method_name.lower()
    full_dim = dataset_embs.shape[1]
    truncation_kinds = ("truncate_first", "truncate_last", "truncate_four")

    if kind == "pca":
        # Fit on the dataset only; questions are projected with the same model.
        reducer = PCA(n_components=target_dim, random_state=42, **pca_kwargs)
        reduced_ds = reducer.fit_transform(dataset_embs)
        reduced_qs = reducer.transform(question_embs)

    elif kind in truncation_kinds:
        if target_dim > full_dim:
            raise ValueError(f"target_dim {target_dim} > original dim {full_dim}")

        if kind == "truncate_first":
            keep = slice(0, target_dim)
        elif kind == "truncate_last":
            keep = slice(full_dim - target_dim, full_dim)
        else:
            if target_dim % 4 != 0:
                raise ValueError("truncate_four requires target_dim divisible by 4")
            width = target_dim // 4
            # Start offsets of the four runs: 0%, 25%, 50%, 75% of the width.
            starts = [int(frac * full_dim) for frac in (0, 0.25, 0.50, 0.75)]
            keep = np.hstack([np.arange(s, s + width) for s in starts])

        reduced_ds = dataset_embs[:, keep]
        reduced_qs = question_embs[:, keep]

    else:
        raise ValueError(f"Unknown method_name '{method_name}'")

    return (
        reduced_ds.astype(np.float32, copy=False),
        reduced_qs.astype(np.float32, copy=False),
    )

### Retrieval Functions

In [None]:
def knn_search(dataset_embs, question_embs, top_k=1):
    """Return the ``top_k`` nearest dataset rows for every question.

    A scikit-learn ``NearestNeighbors`` model (``algorithm='auto'``, library
    default metric) is fitted on ``dataset_embs`` and queried with
    ``question_embs``.

    Args:
        dataset_embs: 2-D array of dataset embeddings.
        question_embs: 2-D array of question embeddings.
        top_k: Number of neighbours to retrieve per question.

    Returns:
        Integer array of shape (Q, top_k) with dataset row indices; the
        distances are computed but discarded.
    """
    searcher = NearestNeighbors(n_neighbors=top_k, algorithm='auto')
    searcher.fit(dataset_embs)
    _, neighbor_ids = searcher.kneighbors(
        question_embs, n_neighbors=top_k, return_distance=True
    )
    return neighbor_ids

### Evaluation / Comparison Function

In [None]:
def evaluate_retrieval(indices, df_dataset, contexts):
    """Compute top-1 exact-match accuracy of a retrieval result.

    For question ``i`` the text of dataset row ``indices[i, 0]`` is compared
    with ``contexts[i]``; a hit requires the two strings to be equal.

    Args:
        indices: 2-D index array; only column 0 is read.
        df_dataset: DataFrame with a ``text`` column, addressed by position.
        contexts: Ground-truth text per question.

    Returns:
        Fraction of questions whose retrieved text equals the ground truth.
    """
    total = len(contexts)
    hits = sum(
        1
        for q in range(total)
        if df_dataset.iloc[indices[q, 0]]["text"] == contexts[q]
    )
    return hits / total

### Complete Process

In [None]:
def run_complete_experiment_with_time(
    df_dataset,           # DataFrame with [text, embedding], shape (N,2)
    dataset_embs,
    question_embs,
    contexts,
    method_name,
    target_dim,
    do_knn=False,
    do_hnsw=False,
    umap_kwargs=None,
    pca_kwargs=None
):
    """Reduce the embeddings, run the requested retrievals and time each step.

    Args:
        df_dataset: Dataset table used to look up retrieved texts.
        dataset_embs: Dataset embedding matrix.
        question_embs: Question embedding matrix.
        contexts: Ground-truth text per question.
        method_name: Reduction method forwarded to ``reduce_embeddings``.
        target_dim: Target dimensionality forwarded to ``reduce_embeddings``.
        do_knn: Run and time the top-1 ``knn_search`` retrieval.
        do_hnsw: Run the ``hnswlib_search`` retrieval.
            NOTE(review): ``hnswlib_search`` is not defined in the visible
            part of this notebook - confirm it exists before enabling.
        umap_kwargs: Forwarded to ``reduce_embeddings``.
        pca_kwargs: Forwarded to ``reduce_embeddings``.

    Returns:
        Dict with keys ``dt_reduce``, ``knn_accuracy``, ``dt_knn``,
        ``hnsw_accuracy``, ``dt_hnsw_build`` and ``dt_hnsw_search``.
        Times are seconds and accuracies are percentages, all rounded to
        4 decimals; entries of a retrieval that was not requested are None.
    """
    # Dimension reduction (timed).
    started = time.perf_counter()
    ds_red, qs_red = reduce_embeddings(
        dataset_embs,
        question_embs,
        method_name=method_name,
        target_dim=target_dim,
        umap_kwargs=umap_kwargs,
        pca_kwargs=pca_kwargs
    )
    reduce_elapsed = time.perf_counter() - started

    report = {
        "dt_reduce": round(reduce_elapsed, 4),
        "knn_accuracy": None,
        "dt_knn": None,
        "hnsw_accuracy": None,
        "dt_hnsw_build": None,
        "dt_hnsw_search": None
    }

    # Exact KNN retrieval: only the search itself is timed, not the scoring.
    if do_knn:
        started = time.perf_counter()
        knn_ids = knn_search(ds_red, qs_red, top_k=1)
        knn_elapsed = time.perf_counter() - started

        knn_fraction = evaluate_retrieval(knn_ids, df_dataset, contexts)
        report["knn_accuracy"] = round(knn_fraction*100, 4)
        report["dt_knn"] = round(knn_elapsed, 4)

    # HNSW retrieval: the helper reports its own build and search times.
    if do_hnsw:
        hnsw_ids, build_sec, search_sec = hnswlib_search(ds_red, qs_red)
        hnsw_fraction = evaluate_retrieval(hnsw_ids, df_dataset, contexts)
        print(f"HNSW Build time: {build_sec} sec   |   HNSW Search time: {search_sec} sec   |   acc: {hnsw_fraction*100:.2f}%\n")
        report["hnsw_accuracy"] = round(hnsw_fraction*100, 4)
        report["dt_hnsw_build"] = round(build_sec, 4)
        report["dt_hnsw_search"] = round(search_sec, 4)

    # Drop the reduced matrices before returning to keep peak memory down.
    del ds_red, qs_red
    gc.collect()

    return report



### Load data

In [None]:
# Load all dataset embedding chunks from Drive into one DataFrame.
# dataset_folder = "/content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/MyProject/embeddings/Linq/4096/dataset"
ALIBABANLP_EMB_PATH  = '/content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/MyProject/embeddings/Alibaba_NLP_3584/dataset'
df_dataset = load_chunked_embeddings(ALIBABANLP_EMB_PATH)
print("Dataset loaded:", df_dataset.shape)

Dataset loaded: (1000000, 2)


In [None]:
# Load the pickled question embeddings table.
QUESTIONS_EMB_PATH = '/content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/MyProject/embeddings/Alibaba_NLP_3584/questions/gte_Qwen2_7B_instruct_questions.pkl'
df_questions = load_questions(QUESTIONS_EMB_PATH)
print("Questions loaded:", df_questions.shape)

Questions loaded: (2470, 2)


In [None]:
# Load the ground-truth contexts (the `context` column of the CSV) as a list.
contexts_csv = "/content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/MyProject/qnc/combined/context_final.csv"
contexts = load_contexts(contexts_csv)
print("Contexts loaded:", len(contexts))

Contexts loaded: 2470


In [None]:
# Stack the per-row `embedding` values into 2-D float32 matrices
# (one row per dataset entry / question).
dataset_embs = np.vstack(df_dataset["embedding"].values).astype("float32")
question_embs= np.vstack(df_questions["embedding"].values).astype("float32")

In [None]:
def plot_accuracy_dicts(acc_dicts: dict, title="Accuracy vs Dimension"):
    """Plot one accuracy-vs-dimension line per entry of ``acc_dicts``.

    Args:
        acc_dicts: Maps a series label to a ``{dimension: accuracy_pct}`` dict.
        title: Figure title.

    The x-axis holds the union of all dimensions, listed from highest to
    lowest. A dimension missing from a series is plotted as ``None`` (the
    line is connected across the gap) and its hover text reads "(missing)".
    The figure is shown directly; nothing is returned.
    """
    # Union of every dimension seen in any series, sorted high -> low.
    seen_dims = set()
    for series in acc_dicts.values():
        seen_dims.update(series)
    dims_desc = sorted(seen_dims, reverse=True)

    symbols = ("circle", "square", "triangle-up", "diamond", "cross")
    fig = go.Figure()

    for position, (label, series) in enumerate(acc_dicts.items()):
        accuracies = [series.get(dim, None) for dim in dims_desc]

        hover_lines = []
        for dim, acc in zip(dims_desc, accuracies):
            if acc is not None:
                hover_lines.append(f"Dim {dim}<br>{label}<br>{acc:.2f}%")
            else:
                hover_lines.append(f"Dim {dim}<br>{label}<br>(missing)")

        trace = go.Scatter(
            x=dims_desc,
            y=accuracies,
            mode="markers+lines",
            marker=dict(symbol=symbols[position % len(symbols)], size=8),
            name=label,
            text=hover_lines,
            hoverinfo="text",
            connectgaps=True
        )
        fig.add_trace(trace)

    fig.update_layout(
        title=title,
        xaxis_title="Embedding Dimension",
        yaxis_title="Accuracy (%)",
        yaxis=dict(range=[0, 100]),
        template="plotly_white",
        hovermode="closest",
        legend_title="Method"
    )
    fig.show()

## Preparation - Progressive KNN

In [None]:
def evaluate_retrieval(indices, df_dataset, contexts):
    """Compute top-1 exact-match accuracy of a retrieval result.

    Args:
        indices: NumPy index array, either 1-D of shape (Q,) or 2-D of shape
            (Q, k). For a 2-D array only column 0 (the best neighbour of
            each question) is scored.
        df_dataset: DataFrame with a ``text`` column, addressed by position.
        contexts: Ground-truth text per question.

    Returns:
        Fraction of questions whose retrieved text equals ``contexts[i]``;
        0.0 when ``contexts`` is empty.
    """
    if indices.ndim == 2:
        # Keep one id per question. Ravelling a (Q, k>1) array would yield
        # Q*k ids and misalign them with the Q ground-truth contexts.
        indices = indices[:, 0]

    total = len(contexts)
    if total == 0:
        # No questions to score: avoid dividing by zero.
        return 0.0

    correct = 0
    for i, row_id in enumerate(indices):
        if df_dataset.iloc[row_id]["text"] == contexts[i]:
            correct += 1
    return correct / total


In [None]:
from typing import List, Dict

In [None]:
def evaluate_exact(indices, df_dataset, contexts):
    """Return the exact-match accuracy of already-flattened row ids.

    Args:
        indices: Iterable with one dataset row position per question.
        df_dataset: DataFrame with a ``text`` column, addressed by position.
        contexts: Ground-truth text per question.

    Returns:
        Number of positions where the row's text equals the ground truth,
        divided by ``len(contexts)``.
    """
    matches = 0
    for row_id, expected in zip(indices, contexts):
        matches += df_dataset.iloc[row_id]["text"] == expected
    return matches / len(contexts)

In [None]:
def progressive_knn_sklearn(
        ds_embs: np.ndarray,
        qs_embs: np.ndarray,
        df_dataset, contexts,
        algorithm = "brute",
        start_dim: int      = 64,
        start_k: int        = 1000,
        max_dim: int        = 512,
        step_factor: int    = 2,      # dim *= step_factor each loop
        step_k: int         = 2,      # how much k we keep, e.g. 2 -> k = k // step_k
        step_add: int       = None,   # add to dim
        verbose: bool       = True):
    """Coarse-to-fine KNN over growing prefixes of the embedding dimensions.

    1. Search the whole dataset on the first ``start_dim`` dimensions and
       keep ``start_k`` neighbours per question.
    2. While the next prefix width is below ``max_dim``: widen the prefix
       (``dim * step_factor`` or ``dim + step_add``), shrink ``k`` to
       ``max(1, k // step_k)`` and search again among the candidates only.
    3. Run a final 1-NN on the first ``max_dim`` dimensions of the
       remaining candidates and score it with ``evaluate_exact``.

    Candidates are pooled across questions: after each stage the candidate
    set is the union of every question's neighbours, and each question is
    searched against that whole union.

    Args:
        ds_embs: Dataset embedding matrix (rows x dims).
        qs_embs: Question embedding matrix with the same column layout.
        df_dataset: Dataset table used by ``evaluate_exact``.
        contexts: Ground-truth text per question.
        algorithm: ``algorithm`` argument for ``NearestNeighbors``.
        start_dim: Prefix width of the initial global search.
        start_k: Neighbours kept per question in the initial search.
        max_dim: Prefix width of the final 1-NN search.
        step_factor: Multiplicative prefix growth; pass None to use
            ``step_add`` instead.
        step_k: Divisor applied to ``k`` at each refinement stage.
        step_add: Additive prefix growth; requires ``step_factor=None``.
        verbose: Print candidate counts per stage.

    Returns:
        Dict with ``accuracy_pct`` (rounded to 4 decimals), ``final_pool``
        (candidate count before the final search) and ``t_total_s`` (search
        time in seconds, rounded to 2 decimals; scoring is not timed).

    Raises:
        ValueError: when both or neither of ``step_add`` / ``step_factor``
            are given, when ``step_factor <= 1``, when ``step_add <= 0``
            (the prefix would never grow and the loop would not end), or
            when ``step_k < 1``.
    """
    # Exactly one growth rule must be selected.
    if (step_add is None) == (step_factor is None):
        raise ValueError("Choose either step_add or step factor")
    if step_add is None and step_factor <= 1:
        raise ValueError("step_factor must be > 1 when step_add is None")
    if step_add is not None and step_add <= 0:
        # A non-positive increment never reaches max_dim (endless loop).
        raise ValueError("step_add must be > 0 when step_factor is None")
    if step_k < 1:
        raise ValueError("step_k must be >= 1")

    def build_index(mat):
        return NearestNeighbors(n_neighbors=1, algorithm=algorithm, metric="euclidean").fit(mat)

    # --------------------------------------------------------
    # 1) initial global search on first start_dim dimensions
    t0 = time.perf_counter()
    index0 = build_index(ds_embs[:, :start_dim])
    _, I = index0.kneighbors(qs_embs[:, :start_dim], n_neighbors=start_k)
    cand_set = np.unique(I.ravel())     # sorted global row ids, pooled over questions
    t_init = time.perf_counter() - t0
    if verbose:
        print(f"[init 0:{start_dim}] unique rows = {cand_set.size:,}")

    dim = start_dim
    k = start_k
    t_slices = 0.0

    # ----------------------------------------
    # 2) refinement loop over intermediate prefix widths (< max_dim)
    while dim < max_dim:
        if step_add is None:
            new_dim = min(dim * step_factor, max_dim)
        else:
            new_dim = min(dim + step_add, max_dim)

        # The max_dim prefix is handled by the final 1-NN stage below.
        if new_dim >= max_dim:
            break

        dim = new_dim
        k = max(1, k // step_k)

        if verbose:
            print(f"[slice 0:{dim}] k={k}")

        t0 = time.perf_counter()
        # build index on current candidate rows
        cand_mat = ds_embs[cand_set][:, :dim]
        idx_local = build_index(cand_mat)

        # query all questions at once
        _, Iq = idx_local.kneighbors(qs_embs[:, :dim], n_neighbors=k)

        # map local → global row IDs, then pool them for the next round
        pools = cand_set[Iq]
        cand_set = np.unique(pools.ravel())
        t_slices += time.perf_counter() - t0

        if verbose:
            print(f"           candidates → {cand_set.size:,}")

    # ------------------------------------------------
    # 3) final 1-NN on remaining rows (0:max_dim)
    t0 = time.perf_counter()
    final_idx = build_index(ds_embs[cand_set][:, :max_dim])
    _, I_final = final_idx.kneighbors(qs_embs[:, :max_dim], n_neighbors=1)
    t_final = time.perf_counter() - t0
    row_ids = cand_set[I_final.ravel()]

    # 4) evaluate
    acc_pct = round(100 * evaluate_exact(row_ids, df_dataset, contexts), 4)
    return {
        "accuracy_pct": acc_pct,
        "final_pool"  : int(cand_set.size),
        "t_total_s": round(t_init+t_slices+t_final, 2)
    }

## Experiment

In [None]:
drive_root = "/content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System"

def save_json(obj, filename):
    """Write the mapping ``obj`` as indented JSON to ``drive_root/filename``.

    Every top-level key is converted with ``str`` first, so non-string keys
    (for example ints or tuples) can be serialised; values are written as
    they are.

    Args:
        obj: Mapping to store (must provide ``.items()``).
        filename: File name relative to ``drive_root``.
    """
    path = f"{drive_root}/{filename}"

    serialisable = {}
    for key, value in obj.items():
        serialisable[str(key)] = value

    with open(path, "w") as handle:
        json.dump(serialisable, handle, indent=2)
    print("✔  saved →", path)

### Truncation and KNN

In [None]:
# Target dimensions for the truncation benchmark, listed from the full
# 3584-d width down to 16; values from 512 down to 32 are spaced 8 apart.
truncate_dims=[3584, 3072, 2048, 1024, 512, 504, 496, 488, 480, 472, 464, 456, 448, 440, 432, 424, 416, 408, 400, 392, 384, 376, 368, 360, 352, 344, 336, 328, 320, 312, 304, 296, 288, 280, 272, 264, 256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 16]
# Per-dimension results: {dim: {"accuracy_pct": ..., "median_time_s": ...}}.
knn_results = {}

#### Initial Run

In [None]:
# Benchmark exact KNN on "truncate_first" embeddings: for every target
# dimension, run the full experiment 10 times and keep the median KNN time.
for dim in truncate_dims:
    times     = []
    accs      = []

    for run in range(10):
        res = run_complete_experiment_with_time(
            df_dataset    = df_dataset,
            dataset_embs  = dataset_embs,
            question_embs = question_embs,
            contexts      = contexts,
            method_name   = "truncate_first",
            target_dim    = dim,
            do_knn        = True,
            do_hnsw       = False
        )
        times.append(res["dt_knn"])
        accs.append(res["knn_accuracy"])
        gc.collect()

    # -------- verify accuracy constant ------------
    # Repeated runs are expected to give the same accuracy; report it if not.
    uniq_acc = set(accs)
    if len(uniq_acc) != 1:
        print(f"accuracy drift at {dim}d → {uniq_acc}")
    acc_pct = round(accs[0], 4)

    # -------- median timing -----------------------
    median_time = round(float(np.median(times)), 4)

    # This line format is what the log-recovery regex later in the notebook parses.
    print(f"Dim {dim:4d} | 10 runs (s): {sorted(round(t,4) for t in times)} "
          f"→ median={median_time:.4f}s | acc={acc_pct:.4f}%")

    knn_results[dim] = {
        "accuracy_pct" : acc_pct,
        "median_time_s": median_time,
        # "times_s"      : [round(t,4) for t in times]   # full trace if needed
    }

# Final table, highest dimension first.
print("\n=== Summary (median time) ===")
for d in sorted(knn_results, reverse=True):
    r = knn_results[d]
    print(f"{d:4d} d  |  {r['accuracy_pct']:7.4f}%  |  {r['median_time_s']:8.4f}s")

Dim 3584 | 10 runs (s): [98.5961, 98.801, 99.1897, 99.2253, 99.336, 99.3803, 100.0219, 100.3155, 100.4691, 100.9872] → median=99.3581s | acc=95.0202%
Dim 3072 | 10 runs (s): [86.0599, 86.7163, 86.8852, 87.2316, 87.4144, 87.5882, 87.6553, 87.7125, 87.7228, 88.1838] → median=87.5013s | acc=94.9798%
Dim 2048 | 10 runs (s): [56.997, 57.2135, 57.3249, 57.4445, 57.4722, 57.5043, 57.5077, 57.748, 57.7928, 58.0219] → median=57.4883s | acc=94.8178%
Dim 1024 | 10 runs (s): [29.5732, 29.5801, 29.6115, 29.6506, 29.6599, 29.7328, 29.8109, 29.849, 29.8504, 29.8842] → median=29.6964s | acc=94.4939%
Dim  512 | 10 runs (s): [16.2359, 16.2738, 16.3007, 16.4075, 16.432, 16.4523, 16.4963, 16.5404, 16.5588, 16.5724] → median=16.4421s | acc=93.8057%
Dim  504 | 10 runs (s): [15.6833, 15.7181, 15.7187, 15.747, 15.7502, 15.7692, 15.7708, 15.7761, 15.7881, 15.9188] → median=15.7597s | acc=93.8866%
Dim  496 | 10 runs (s): [15.4512, 15.4717, 15.4963, 15.5499, 15.6168, 15.6298, 15.7222, 15.7654, 15.8014, 15.8141] 

#### Continuous run

In [None]:
log_text = """
Dim 3584 | 10 runs (s): [98.5961, 98.801, 99.1897, 99.2253, 99.336, 99.3803, 100.0219, 100.3155, 100.4691, 100.9872] → median=99.3581s | acc=95.0202%
Dim 3072 | 10 runs (s): [86.0599, 86.7163, 86.8852, 87.2316, 87.4144, 87.5882, 87.6553, 87.7125, 87.7228, 88.1838] → median=87.5013s | acc=94.9798%
Dim 2048 | 10 runs (s): [56.997, 57.2135, 57.3249, 57.4445, 57.4722, 57.5043, 57.5077, 57.748, 57.7928, 58.0219] → median=57.4883s | acc=94.8178%
Dim 1024 | 10 runs (s): [29.5732, 29.5801, 29.6115, 29.6506, 29.6599, 29.7328, 29.8109, 29.849, 29.8504, 29.8842] → median=29.6964s | acc=94.4939%
Dim  512 | 10 runs (s): [16.2359, 16.2738, 16.3007, 16.4075, 16.432, 16.4523, 16.4963, 16.5404, 16.5588, 16.5724] → median=16.4421s | acc=93.8057%
Dim  504 | 10 runs (s): [15.6833, 15.7181, 15.7187, 15.747, 15.7502, 15.7692, 15.7708, 15.7761, 15.7881, 15.9188] → median=15.7597s | acc=93.8866%
Dim  496 | 10 runs (s): [15.4512, 15.4717, 15.4963, 15.5499, 15.6168, 15.6298, 15.7222, 15.7654, 15.8014, 15.8141] → median=15.6233s | acc=93.8057%
Dim  488 | 10 runs (s): [15.2037, 15.2578, 15.2668, 15.2745, 15.3309, 15.3498, 15.4306, 15.4317, 15.5069, 15.5605] → median=15.3404s | acc=93.7652%
Dim  480 | 10 runs (s): [14.8889, 14.8997, 14.9975, 15.0141, 15.0499, 15.0574, 15.0699, 15.1098, 15.1467, 15.1985] → median=15.0536s | acc=93.7247%
Dim  472 | 10 runs (s): [14.5879, 14.6479, 14.6859, 14.7607, 14.7785, 14.8168, 14.8425, 14.8713, 14.89, 15.0226] → median=14.7977s | acc=93.7652%
Dim  464 | 10 runs (s): [14.5497, 14.5882, 14.6482, 14.6722, 14.687, 14.7447, 14.7479, 14.8413, 14.9796, 15.1213] → median=14.7158s | acc=93.8057%
Dim  456 | 10 runs (s): [14.4664, 14.4816, 14.4982, 14.5535, 14.5573, 14.581, 14.6309, 14.6627, 14.6728, 14.6753] → median=14.5692s | acc=93.7652%
Dim  448 | 10 runs (s): [14.2048, 14.2834, 14.3134, 14.3624, 14.3937, 14.4, 14.4168, 14.4486, 14.4645, 14.4846] → median=14.3969s | acc=93.6842%
Dim  440 | 10 runs (s): [14.0753, 14.0891, 14.1456, 14.1815, 14.2217, 14.2545, 14.2664, 14.2836, 14.2958, 14.5102] → median=14.2381s | acc=93.5628%
Dim  432 | 10 runs (s): [13.8269, 13.8611, 13.8906, 13.9279, 13.931, 13.9546, 13.9628, 14.0422, 14.0854, 14.351] → median=13.9428s | acc=93.5223%
Dim  424 | 10 runs (s): [13.7673, 13.875, 13.9283, 13.9464, 13.9882, 14.021, 14.0299, 14.041, 14.0591, 14.0683] → median=14.0046s | acc=93.6032%
Dim  416 | 10 runs (s): [13.5113, 13.5763, 13.5872, 13.6636, 13.6711, 13.6742, 13.715, 13.7419, 13.7753, 13.8312] → median=13.6727s | acc=93.6032%
Dim  408 | 10 runs (s): [13.4591, 13.4819, 13.486, 13.5092, 13.5678, 13.5779, 13.5953, 13.6078, 13.6158, 13.6311] → median=13.5728s | acc=93.5628%
Dim  400 | 10 runs (s): [13.2636, 13.2919, 13.3339, 13.3351, 13.3786, 13.4024, 13.4073, 13.429, 13.4519, 13.5268] → median=13.3905s | acc=93.6032%
Dim  392 | 10 runs (s): [13.0054, 13.0158, 13.0221, 13.043, 13.0485, 13.058, 13.0715, 13.0716, 13.1578, 13.1924] → median=13.0533s | acc=93.5223%
Dim  384 | 10 runs (s): [12.6727, 12.6807, 12.7358, 12.7405, 12.8385, 12.8441, 12.88, 12.9338, 12.9409, 13.0092] → median=12.8413s | acc=93.4008%
Dim  376 | 10 runs (s): [12.2072, 12.3001, 12.3019, 12.3514, 12.4022, 12.4283, 12.4582, 12.5006, 12.5166, 12.6906] → median=12.4153s | acc=93.2794%
Dim  368 | 10 runs (s): [11.9974, 12.003, 12.059, 12.0842, 12.0937, 12.1033, 12.1038, 12.1167, 12.1689, 12.1987] → median=12.0985s | acc=93.2794%
Dim  360 | 10 runs (s): [11.7829, 11.8199, 11.822, 11.8401, 11.9608, 11.9768, 12.0176, 12.077, 12.0907, 12.1685] → median=11.9688s | acc=93.2389%
Dim  352 | 10 runs (s): [11.6514, 11.6531, 11.6847, 11.6893, 11.6976, 11.7434, 11.7986, 11.8156, 11.8367, 11.9119] → median=11.7205s | acc=93.1579%
Dim  344 | 10 runs (s): [11.3963, 11.4253, 11.4263, 11.4331, 11.4962, 11.538, 11.5711, 11.5807, 11.62, 11.6891] → median=11.5171s | acc=92.9555%
Dim  336 | 10 runs (s): [11.1637, 11.238, 11.3026, 11.3328, 11.368, 11.3922, 11.4234, 11.4921, 11.5174, 11.5184] → median=11.3801s | acc=92.8745%
Dim  328 | 10 runs (s): [11.0132, 11.1046, 11.1064, 11.1222, 11.1414, 11.1721, 11.2067, 11.2515, 11.2886, 11.3223] → median=11.1568s | acc=92.8745%
Dim  320 | 10 runs (s): [10.7636, 10.7952, 10.8004, 10.8235, 10.8795, 10.8833, 10.9048, 10.9172, 11.1553, 11.1705] → median=10.8814s | acc=92.7126%
Dim  312 | 10 runs (s): [10.4112, 10.6056, 10.6257, 10.6973, 10.7043, 10.7789, 10.801, 10.8165, 10.8248, 10.8853] → median=10.7416s | acc=92.5911%
Dim  304 | 10 runs (s): [10.3105, 10.3392, 10.462, 10.4622, 10.5315, 10.5439, 10.5689, 10.5969, 10.6186, 10.9144] → median=10.5377s | acc=92.3887%
Dim  296 | 10 runs (s): [10.1995, 10.245, 10.3828, 10.3868, 10.3947, 10.442, 10.4581, 10.4799, 10.4836, 10.5146] → median=10.4184s | acc=92.5506%
Dim  288 | 10 runs (s): [10.0292, 10.05, 10.0658, 10.1435, 10.1846, 10.1865, 10.2795, 10.2816, 10.337, 10.3875] → median=10.1855s | acc=93.0769%
Dim  280 | 10 runs (s): [9.6767, 9.8797, 9.9197, 9.9205, 9.962, 9.9644, 10.063, 10.0676, 10.1105, 10.1476] → median=9.9632s | acc=93.1579%
Dim  272 | 10 runs (s): [9.6072, 9.6148, 9.7516, 9.7561, 9.7765, 9.8097, 9.8101, 9.8265, 9.8677, 9.878] → median=9.7931s | acc=92.9555%
Dim  264 | 10 runs (s): [9.3566, 9.4343, 9.5161, 9.5223, 9.5341, 9.5529, 9.5867, 9.6196, 9.6285, 9.7512] → median=9.5435s | acc=93.0364%
Dim  256 | 10 runs (s): [9.2572, 9.3069, 9.3268, 9.3443, 9.366, 9.4019, 9.427, 9.5305, 9.5872, 9.6359] → median=9.3839s | acc=92.7935%
Dim  248 | 10 runs (s): [9.032, 9.0337, 9.0734, 9.0907, 9.1373, 9.1381, 9.1454, 9.1744, 9.2321, 9.2636] → median=9.1377s | acc=92.8745%
Dim  240 | 10 runs (s): [8.8112, 8.8638, 8.8775, 8.8779, 8.8955, 8.943, 8.9608, 8.9784, 9.0025, 9.1093] → median=8.9192s | acc=92.5911%
Dim  232 | 10 runs (s): [8.4606, 8.4638, 8.5028, 8.507, 8.5726, 8.5786, 8.6164, 8.6218, 8.635, 8.6746] → median=8.5756s | acc=92.4291%
Dim  224 | 10 runs (s): [8.4144, 8.4225, 8.4747, 8.5008, 8.516, 8.5539, 8.6318, 8.6637, 8.6835, 8.7037] → median=8.5350s | acc=92.2267%
Dim  216 | 10 runs (s): [8.0706, 8.1247, 8.172, 8.188, 8.1889, 8.2083, 8.2097, 8.3105, 8.3378, 8.3743] → median=8.1986s | acc=91.9838%
Dim  208 | 10 runs (s): [7.7105, 7.7788, 7.8385, 7.9352, 7.9647, 7.9821, 7.9951, 8.0476, 8.0731, 8.0871] → median=7.9734s | acc=91.8623%
"""

In [None]:
# Matches one benchmark log line and captures, in order:
# (1) the dimension, (2) the median time in seconds, (3) the accuracy in percent.
pattern = re.compile(
    r"Dim\s+(\d+)\s+\|.*?median=([\d\.]+)s\s+\|\s+acc=([\d\.]+)"
)

In [None]:
# Rebuild knn_results from the pasted log text so finished dimensions
# do not have to be benchmarked again.
knn_results = {}
for dim, t, acc in pattern.findall(log_text):
    knn_results[int(dim)] = {
        "accuracy_pct" : round(float(acc), 4),
        "median_time_s": round(float(t),   4)
    }

print("Recovered ↓"); pprint.pprint(knn_results)

Recovered ↓
{208: {'accuracy_pct': 91.8623, 'median_time_s': 7.9734},
 216: {'accuracy_pct': 91.9838, 'median_time_s': 8.1986},
 224: {'accuracy_pct': 92.2267, 'median_time_s': 8.535},
 232: {'accuracy_pct': 92.4291, 'median_time_s': 8.5756},
 240: {'accuracy_pct': 92.5911, 'median_time_s': 8.9192},
 248: {'accuracy_pct': 92.8745, 'median_time_s': 9.1377},
 256: {'accuracy_pct': 92.7935, 'median_time_s': 9.3839},
 264: {'accuracy_pct': 93.0364, 'median_time_s': 9.5435},
 272: {'accuracy_pct': 92.9555, 'median_time_s': 9.7931},
 280: {'accuracy_pct': 93.1579, 'median_time_s': 9.9632},
 288: {'accuracy_pct': 93.0769, 'median_time_s': 10.1855},
 296: {'accuracy_pct': 92.5506, 'median_time_s': 10.4184},
 304: {'accuracy_pct': 92.3887, 'median_time_s': 10.5377},
 312: {'accuracy_pct': 92.5911, 'median_time_s': 10.7416},
 320: {'accuracy_pct': 92.7126, 'median_time_s': 10.8814},
 328: {'accuracy_pct': 92.8745, 'median_time_s': 11.1568},
 336: {'accuracy_pct': 92.8745, 'median_time_s': 11.380

In [None]:
# Resume the truncation benchmark: dimensions already present in
# knn_results (recovered from the log) are skipped.
done_dims = set(knn_results)

for dim in truncate_dims:
    if dim in done_dims:
        continue

    times     = []
    accs      = []

    for run in range(10):
        res = run_complete_experiment_with_time(
            df_dataset    = df_dataset,
            dataset_embs  = dataset_embs,
            question_embs = question_embs,
            contexts      = contexts,
            method_name   = "truncate_first",
            target_dim    = dim,
            do_knn        = True,
            do_hnsw       = False
        )
        times.append(res["dt_knn"])
        accs.append(res["knn_accuracy"])
        gc.collect()

    # -------- verify accuracy constant ------------
    # Repeated runs are expected to give the same accuracy; report it if not.
    uniq_acc = set(accs)
    if len(uniq_acc) != 1:
        print(f"accuracy drift at {dim}d → {uniq_acc}")
    acc_pct = round(accs[0], 4)

    # -------- median timing -----------------------
    median_time = round(float(np.median(times)), 4)

    print(f"Dim {dim:4d} | 10 runs (s): {sorted(round(t,4) for t in times)} "
          f"→ median={median_time:.4f}s | acc={acc_pct:.4f}%")

    knn_results[dim] = {
        "accuracy_pct" : acc_pct,
        "median_time_s": median_time,
        # "times_s"      : [round(t,4) for t in times]   # full trace if needed
    }

# Final table over recovered and newly measured dimensions, highest first.
print("\n=== Summary (median time) ===")
for d in sorted(knn_results, reverse=True):
    r = knn_results[d]
    print(f"{d:4d} d  |  {r['accuracy_pct']:7.4f}%  |  {r['median_time_s']:8.4f}s")

Dim  200 | 10 runs (s): [7.5162, 7.5172, 7.5621, 7.6292, 7.6578, 7.7263, 7.7572, 7.7621, 7.7914, 7.8664] → median=7.6921s | acc=91.7409%
Dim  192 | 10 runs (s): [7.1926, 7.2628, 7.3883, 7.4155, 7.443, 7.4486, 7.4799, 7.5248, 7.5528, 7.6066] → median=7.4458s | acc=91.5789%
Dim  184 | 10 runs (s): [7.0385, 7.0756, 7.1119, 7.1401, 7.2211, 7.2404, 7.2559, 7.29, 7.3343, 7.358] → median=7.2308s | acc=91.3765%
Dim  176 | 10 runs (s): [6.8027, 6.8515, 6.8629, 6.9009, 6.9367, 6.9654, 7.0106, 7.1342, 7.1767, 7.2002] → median=6.9511s | acc=91.0526%
Dim  168 | 10 runs (s): [6.5749, 6.6374, 6.7385, 6.8073, 6.8434, 6.8462, 6.8656, 6.9242, 6.9403, 7.0034] → median=6.8448s | acc=90.8502%
Dim  160 | 10 runs (s): [6.3796, 6.3903, 6.4692, 6.5586, 6.5947, 6.6076, 6.6183, 6.6364, 6.6504, 6.6659] → median=6.6011s | acc=90.6073%
Dim  152 | 10 runs (s): [6.1534, 6.2522, 6.2751, 6.29, 6.3126, 6.3203, 6.4021, 6.4164, 6.4816, 6.4902] → median=6.3164s | acc=90.0810%
Dim  144 | 10 runs (s): [5.9818, 5.9945, 6.0303

In [None]:
# Persist the per-dimension KNN results as JSON under drive_root.
save_json(knn_results, "truncate_first_knn_medians.json")

✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/truncate_first_knn_medians.json


In [None]:
# Dump the full results dict for a quick visual check.
print(knn_results)

{3584: {'accuracy_pct': 95.0202, 'median_time_s': 99.3581}, 3072: {'accuracy_pct': 94.9798, 'median_time_s': 87.5013}, 2048: {'accuracy_pct': 94.8178, 'median_time_s': 57.4883}, 1024: {'accuracy_pct': 94.4939, 'median_time_s': 29.6964}, 512: {'accuracy_pct': 93.8057, 'median_time_s': 16.4421}, 504: {'accuracy_pct': 93.8866, 'median_time_s': 15.7597}, 496: {'accuracy_pct': 93.8057, 'median_time_s': 15.6233}, 488: {'accuracy_pct': 93.7652, 'median_time_s': 15.3404}, 480: {'accuracy_pct': 93.7247, 'median_time_s': 15.0536}, 472: {'accuracy_pct': 93.7652, 'median_time_s': 14.7977}, 464: {'accuracy_pct': 93.8057, 'median_time_s': 14.7158}, 456: {'accuracy_pct': 93.7652, 'median_time_s': 14.5692}, 448: {'accuracy_pct': 93.6842, 'median_time_s': 14.3969}, 440: {'accuracy_pct': 93.5628, 'median_time_s': 14.2381}, 432: {'accuracy_pct': 93.5223, 'median_time_s': 13.9428}, 424: {'accuracy_pct': 93.6032, 'median_time_s': 14.0046}, 416: {'accuracy_pct': 93.6032, 'median_time_s': 13.6727}, 408: {'ac

### Progressive Truncation KNN

In [None]:
# Shared configuration for the progressive-KNN benchmark cells.
prog_stats = {}                                        # (start_dim, max_dim, start_k) -> stats
dim_grid   = [64, 128, 256, 512, 1024, 2048, 3584]    # candidate start / max prefix widths
start_k_ls = [2**i for i in range(2, 11)]              # initial k values: 4, 8, ..., 1024
JSON_NAME  = "progressive_knn_medians.json"
JSON_PATH  = f"{drive_root}/{JSON_NAME}"

#### Initial Run

In [None]:
# Grid benchmark of progressive KNN: every (start_dim < max_dim, start_k)
# combination is run 10 times; the median time and the accuracy are stored.
prog_stats = {}
dim_grid   = [64, 128, 256, 512, 1024, 2048, 3584]
start_k_ls = [2**i for i in range(2, 11)]

for startD in dim_grid:
    for maxD in dim_grid:
        if startD >= maxD:
            continue        # strictly ascending range

        for k0 in start_k_ls:
            key = (startD, maxD, k0)
            times, accs = [], []

            for run in range(10):
                res = progressive_knn_sklearn(
                    ds_embs     = dataset_embs,
                    qs_embs     = question_embs,
                    df_dataset  = df_dataset,
                    contexts    = contexts,
                    start_dim   = startD,
                    start_k     = k0,
                    max_dim     = maxD,
                    step_factor = 2,
                    step_k      = 2,
                    step_add    = None,
                    verbose     = False    # suppress per‑slice chatter
                )
                times.append(res["t_total_s"])
                accs.append(res["accuracy_pct"])
                gc.collect()

            # ---------- summarise one scenario -----------------
            times_sorted  = sorted(times)
            median_time   = round(float(np.median(times_sorted)), 4)
            acc_unique    = set(accs)
            # acc_final     = round(accs[0], 4) if len(acc_unique)==1 else sorted(acc_unique)[-1]
            # Repeated runs are expected to agree; report it if they do not.
            if len(acc_unique) != 1:
                print(f"accuracy drift at [{startD:4d}->{maxD:<4d} | k={k0:4d}]")
            acc_final = round(accs[0], 4)

            print(f"[{startD:4d}->{maxD:<4d} | k={k0:4d}] "
                  f"times={times_sorted}  →  median={median_time:.4f}s  "
                  f"| acc={acc_final:.4f}%")

            prog_stats[key] = {
                "accuracy_pct" : acc_final,
                "median_time_s": median_time,
                # "times_s"      : [round(t, 4) for t in times_sorted]
            }

# ---------------- summary ------------------
print("\n=== MEDIAN TIMES (s) ===")
for (sD,mD,k0), stats in sorted(prog_stats.items()):
    print(f"{sD:4}->{mD:<4}  k={k0:4} | "
          f"{stats['accuracy_pct']:7.4f}%  | {stats['median_time_s']:8.4f}s")

# --------------- save to JSON for later plotting -----------------
# The tuple keys are stored as their str() form, e.g. "(64, 128, 4)".
# Results are written only here, after the whole grid has finished.
save_json(prog_stats, "progressive_knn_medians.json")
print("\nSaved → progressive_knn_medians.json")

[  64->128  | k=   4] times=[3.93, 3.93, 3.95, 3.97, 4.02, 4.06, 4.07, 4.07, 4.17, 4.17]  →  median=4.0400s  | acc=83.7247%
[  64->128  | k=   8] times=[3.99, 3.99, 4.0, 4.04, 4.06, 4.08, 4.08, 4.18, 4.27, 4.27]  →  median=4.0700s  | acc=85.2227%
[  64->128  | k=  16] times=[4.22, 4.25, 4.25, 4.27, 4.28, 4.3, 4.36, 4.38, 4.45, 4.52]  →  median=4.2900s  | acc=86.4372%
[  64->128  | k=  32] times=[4.58, 4.58, 4.68, 4.71, 4.72, 4.78, 4.79, 4.79, 4.88, 4.91]  →  median=4.7500s  | acc=87.2470%
[  64->128  | k=  64] times=[5.36, 5.39, 5.39, 5.4, 5.43, 5.52, 5.54, 5.55, 5.56, 5.6]  →  median=5.4750s  | acc=87.8947%
[  64->128  | k= 128] times=[6.34, 6.35, 6.4, 6.43, 6.49, 6.53, 6.54, 6.59, 6.62, 6.62]  →  median=6.5100s  | acc=88.5020%
[  64->128  | k= 256] times=[7.63, 7.78, 8.0, 8.08, 8.09, 8.12, 8.14, 8.17, 8.2, 8.22]  →  median=8.1050s  | acc=88.5830%
[  64->128  | k= 512] times=[9.14, 9.27, 9.28, 9.32, 9.33, 9.36, 9.37, 9.39, 9.45, 9.47]  →  median=9.3450s  | acc=88.7449%
[  64->128  | k

#### Continue Run 1

In [None]:
# Reload previously saved progressive-KNN results, if the file exists.
# Keys were stored as strings such as "(64, 128, 4)"; strip the parentheses
# and parse the comma-separated parts back into an int tuple.
if os.path.exists(JSON_PATH):
    with open(JSON_PATH) as f:
        prog_stats = {tuple(map(int, k.strip("()").split(","))): v
                      for k, v in json.load(f).items()}
    print(f"↩️  loaded {len(prog_stats)} scenarios from {JSON_NAME}")

↩️  loaded 43 scenarios from progressive_knn_medians.json


In [None]:
LOG_TEXT = """
[  64->128  | k=   4] times=[3.93, 3.93, 3.95, 3.97, 4.02, 4.06, 4.07, 4.07, 4.17, 4.17]  →  median=4.0400s  | acc=83.7247%
[  64->128  | k=   8] times=[3.99, 3.99, 4.0, 4.04, 4.06, 4.08, 4.08, 4.18, 4.27, 4.27]  →  median=4.0700s  | acc=85.2227%
[  64->128  | k=  16] times=[4.22, 4.25, 4.25, 4.27, 4.28, 4.3, 4.36, 4.38, 4.45, 4.52]  →  median=4.2900s  | acc=86.4372%
[  64->128  | k=  32] times=[4.58, 4.58, 4.68, 4.71, 4.72, 4.78, 4.79, 4.79, 4.88, 4.91]  →  median=4.7500s  | acc=87.2470%
[  64->128  | k=  64] times=[5.36, 5.39, 5.39, 5.4, 5.43, 5.52, 5.54, 5.55, 5.56, 5.6]  →  median=5.4750s  | acc=87.8947%
[  64->128  | k= 128] times=[6.34, 6.35, 6.4, 6.43, 6.49, 6.53, 6.54, 6.59, 6.62, 6.62]  →  median=6.5100s  | acc=88.5020%
[  64->128  | k= 256] times=[7.63, 7.78, 8.0, 8.08, 8.09, 8.12, 8.14, 8.17, 8.2, 8.22]  →  median=8.1050s  | acc=88.5830%
[  64->128  | k= 512] times=[9.14, 9.27, 9.28, 9.32, 9.33, 9.36, 9.37, 9.39, 9.45, 9.47]  →  median=9.3450s  | acc=88.7449%
[  64->128  | k=1024] times=[12.86, 12.91, 12.95, 12.95, 12.99, 13.0, 13.0, 13.01, 13.05, 13.09]  →  median=12.9950s  | acc=88.7449%
[  64->256  | k=   4] times=[3.95, 3.95, 3.96, 3.99, 4.02, 4.06, 4.07, 4.15, 4.17, 4.21]  →  median=4.0400s  | acc=84.8178%
[  64->256  | k=   8] times=[4.1, 4.12, 4.14, 4.15, 4.23, 4.23, 4.31, 4.33, 4.33, 4.42]  →  median=4.2300s  | acc=86.6397%
[  64->256  | k=  16] times=[4.42, 4.44, 4.46, 4.46, 4.5, 4.52, 4.6, 4.68, 4.73, 4.75]  →  median=4.5100s  | acc=88.4211%
[  64->256  | k=  32] times=[5.0, 5.03, 5.05, 5.08, 5.1, 5.18, 5.24, 5.28, 5.3, 5.33]  →  median=5.1400s  | acc=89.7571%
[  64->256  | k=  64] times=[6.06, 6.07, 6.07, 6.13, 6.14, 6.24, 6.29, 6.34, 6.36, 6.45]  →  median=6.1900s  | acc=90.5263%
[  64->256  | k= 128] times=[7.74, 7.77, 7.8, 7.81, 7.83, 7.83, 7.85, 7.87, 7.88, 7.9]  →  median=7.8300s  | acc=91.5385%
[  64->256  | k= 256] times=[10.59, 10.59, 10.61, 10.62, 10.65, 10.66, 10.68, 10.68, 10.79, 10.81]  →  median=10.6550s  | acc=91.9838%
[  64->256  | k= 512] times=[14.39, 14.47, 14.55, 14.59, 14.59, 14.6, 14.68, 14.84, 14.84, 14.84]  →  median=14.5950s  | acc=92.4696%
[  64->256  | k=1024] times=[19.84, 19.84, 19.88, 19.89, 19.94, 19.94, 19.96, 19.97, 19.97, 20.04]  →  median=19.9400s  | acc=92.6721%
[  64->512  | k=   4] times=[3.98, 4.01, 4.01, 4.02, 4.1, 4.12, 4.18, 4.18, 4.22, 4.28]  →  median=4.1100s  | acc=84.8178%
[  64->512  | k=   8] times=[4.21, 4.22, 4.25, 4.28, 4.3, 4.32, 4.32, 4.39, 4.5, 4.55]  →  median=4.3100s  | acc=86.8016%
[  64->512  | k=  16] times=[4.63, 4.64, 4.67, 4.67, 4.67, 4.77, 4.77, 4.92, 4.92, 4.93]  →  median=4.7200s  | acc=88.7854%
[  64->512  | k=  32] times=[5.39, 5.44, 5.46, 5.52, 5.54, 5.67, 5.72, 5.72, 5.72, 5.79]  →  median=5.6050s  | acc=90.1215%
[  64->512  | k=  64] times=[6.8, 6.88, 7.01, 7.03, 7.04, 7.06, 7.07, 7.1, 7.13, 7.22]  →  median=7.0500s  | acc=91.0526%
[  64->512  | k= 128] times=[9.17, 9.28, 9.31, 9.32, 9.35, 9.37, 9.44, 9.49, 9.5, 9.59]  →  median=9.3600s  | acc=92.1457%
[  64->512  | k= 256] times=[13.1, 13.11, 13.12, 13.19, 13.19, 13.23, 13.26, 13.27, 13.29, 13.36]  →  median=13.2100s  | acc=92.7126%
[  64->512  | k= 512] times=[18.72, 18.79, 18.97, 19.02, 19.03, 19.05, 19.11, 19.11, 19.12, 19.15]  →  median=19.0400s  | acc=93.3198%
[  64->512  | k=1024] times=[26.87, 27.0, 27.09, 27.16, 27.2, 27.22, 27.22, 27.26, 27.35, 27.36]  →  median=27.2100s  | acc=93.5628%
[  64->1024 | k=   4] times=[4.32, 4.34, 4.34, 4.36, 4.38, 4.4, 4.43, 4.49, 4.5, 4.53]  →  median=4.3900s  | acc=84.8178%
[  64->1024 | k=   8] times=[4.52, 4.53, 4.55, 4.56, 4.59, 4.6, 4.62, 4.64, 4.64, 4.71]  →  median=4.5950s  | acc=86.8016%
[  64->1024 | k=  16] times=[4.94, 5.02, 5.02, 5.05, 5.06, 5.11, 5.11, 5.12, 5.12, 5.22]  →  median=5.0850s  | acc=89.1498%
[  64->1024 | k=  32] times=[5.91, 5.95, 5.97, 5.98, 6.01, 6.02, 6.04, 6.07, 6.08, 6.1]  →  median=6.0150s  | acc=90.5263%
[  64->1024 | k=  64] times=[7.71, 7.72, 7.78, 7.79, 7.81, 7.83, 7.83, 7.83, 7.83, 7.84]  →  median=7.8200s  | acc=91.4980%
[  64->1024 | k= 128] times=[10.26, 10.29, 10.35, 10.47, 10.49, 10.5, 10.51, 10.51, 10.54, 10.82]  →  median=10.4950s  | acc=92.6316%
[  64->1024 | k= 256] times=[13.67, 13.71, 13.81, 13.82, 13.82, 14.01, 14.09, 14.1, 14.18, 14.99]  →  median=13.9150s  | acc=93.2794%
[  64->1024 | k= 512] times=[20.37, 20.6, 20.62, 20.65, 20.73, 20.79, 20.8, 20.81, 20.83, 20.86]  →  median=20.7600s  | acc=93.9271%
[  64->1024 | k=1024] times=[31.49, 32.74, 33.44, 33.58, 33.65, 33.8, 33.81, 34.01, 34.16, 34.82]  →  median=33.7250s  | acc=94.2105%
[  64->2048 | k=   4] times=[4.31, 4.34, 4.36, 4.36, 4.37, 4.42, 4.44, 4.51, 4.56, 4.59]  →  median=4.3950s  | acc=84.8178%
[  64->2048 | k=   8] times=[4.51, 4.52, 4.55, 4.59, 4.62, 4.65, 4.77, 4.79, 4.88, 4.91]  →  median=4.6350s  | acc=86.8016%
[  64->2048 | k=  16] times=[4.86, 4.9, 4.91, 4.92, 4.93, 4.94, 4.98, 5.03, 5.1, 5.1]  →  median=4.9350s  | acc=89.1498%
[  64->2048 | k=  32] times=[5.74, 5.77, 5.84, 5.91, 5.95, 6.0, 6.01, 6.02, 6.05, 6.12]  →  median=5.9750s  | acc=90.7287%
[  64->2048 | k=  64] times=[7.56, 7.56, 7.59, 7.6, 7.68, 7.7, 7.74, 7.74, 7.76, 7.76]  →  median=7.6900s  | acc=91.7409%
[  64->2048 | k= 128] times=[10.64, 10.65, 10.72, 10.72, 10.72, 10.75, 10.76, 10.78, 10.85, 11.64]  →  median=10.7350s  | acc=92.7935%
[  64->2048 | k= 256] times=[17.14, 17.18, 17.19, 17.2, 17.34, 17.38, 17.51, 17.56, 17.61, 17.69]  →  median=17.3600s  | acc=93.5223%
"""

In [None]:
# Import scenarios from the pasted benchmark log (LOG_TEXT) into prog_stats.
# Scenarios that are already present in prog_stats are left untouched.
if LOG_TEXT.strip():
    # One match per log line: start dim, max dim, start k, the raw list of
    # run times, the reported median (seconds) and the accuracy (percent).
    pat = re.compile(
        r"\[\s*(\d+)->(\d+)\s*\|\s*k=\s*(\d+)\]\s*times=\[([0-9\.,\s]+)\]"
        r".*?median=([\d\.]+)s.*?acc=([\d\.]+)%"
    )
    added = 0
    # The individual run times are captured by the pattern but not stored,
    # so they are not parsed here.
    for sD, mD, k0, _t_list, med, acc in pat.findall(LOG_TEXT):
        key = (int(sD), int(mD), int(k0))
        if key in prog_stats:
            continue
        prog_stats[key] = {
            "accuracy_pct" : round(float(acc), 4),
            "median_time_s": round(float(med), 4),
        }
        added += 1
    # Report how many scenarios were actually taken from the log.
    print(f"imported {added} scenarios from LOG_TEXT")

In [None]:
# Dump the in-memory results dict to check what it currently contains.
print(prog_stats)

{(64, 128, 4): {'accuracy_pct': 83.7247, 'median_time_s': 4.04}, (64, 128, 8): {'accuracy_pct': 85.2227, 'median_time_s': 4.07}, (64, 128, 16): {'accuracy_pct': 86.4372, 'median_time_s': 4.29}, (64, 128, 32): {'accuracy_pct': 87.247, 'median_time_s': 4.75}, (64, 128, 64): {'accuracy_pct': 87.8947, 'median_time_s': 5.475}, (64, 128, 128): {'accuracy_pct': 88.502, 'median_time_s': 6.51}, (64, 128, 256): {'accuracy_pct': 88.583, 'median_time_s': 8.105}, (64, 128, 512): {'accuracy_pct': 88.7449, 'median_time_s': 9.345}, (64, 128, 1024): {'accuracy_pct': 88.7449, 'median_time_s': 12.995}, (64, 256, 4): {'accuracy_pct': 84.8178, 'median_time_s': 4.04}, (64, 256, 8): {'accuracy_pct': 86.6397, 'median_time_s': 4.23}, (64, 256, 16): {'accuracy_pct': 88.4211, 'median_time_s': 4.51}, (64, 256, 32): {'accuracy_pct': 89.7571, 'median_time_s': 5.14}, (64, 256, 64): {'accuracy_pct': 90.5263, 'median_time_s': 6.19}, (64, 256, 128): {'accuracy_pct': 91.5385, 'median_time_s': 7.83}, (64, 256, 256): {'ac

In [None]:
# Persist the current results via the save_json helper (defined elsewhere in this notebook).
save_json(prog_stats, JSON_NAME)

✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/progressive_knn_medians.json


In [None]:
# Benchmark sweep: every (start dim, max dim, start k) combination that is not
# yet in prog_stats is run 10 times; the median time and the accuracy are
# stored and written to disk after each scenario.
for startD in dim_grid:
    for maxD in dim_grid:
        if maxD <= startD:            # only strictly ascending dimension ranges
            continue
        for k0 in start_k_ls:
            key = (startD, maxD, k0)
            if key in prog_stats:     # measured earlier or imported from a log
                continue

            scenario_tag = f"[{startD:4d}->{maxD:<4d} | k={k0:4d}]"

            times, accs = [], []
            for _ in range(10):
                res = progressive_knn_sklearn(
                    ds_embs=dataset_embs,
                    qs_embs=question_embs,
                    df_dataset=df_dataset,
                    contexts=contexts,
                    start_dim=startD,
                    start_k=k0,
                    max_dim=maxD,
                    step_factor=2,
                    step_k=2,
                    step_add=None,
                    verbose=False,
                )
                times.append(res["t_total_s"])
                accs.append(res["accuracy_pct"])
                gc.collect()

            times_sorted = [round(t, 4) for t in times]
            times_sorted.sort()
            median_time = round(float(np.median(times_sorted)), 4)

            # All repetitions are expected to report the same accuracy;
            # warn when they do not, but still record the first value.
            acc_unique = set(accs)
            if len(acc_unique) != 1:
                print(f"accuracy drift at {scenario_tag}")
            acc_final = round(accs[0], 4)

            # Same line layout as the log lines parsed from LOG_TEXT.
            print(f"{scenario_tag} "
                  f"times={times_sorted}  →  median={median_time:.4f}s  "
                  f"| acc={acc_final:.4f}%")

            prog_stats[key] = {
                "accuracy_pct": acc_final,
                "median_time_s": median_time,
            }
            save_json(prog_stats, JSON_NAME)   # rewrite the file after every scenario

# --------------------------------------------------------------
# Final summary of every recorded scenario, ordered by key.
print("\n=== MEDIAN TIMES (s) ===")
for scenario in sorted(prog_stats):
    sD, mD, k0 = scenario
    st = prog_stats[scenario]
    print(f"{sD:4}->{mD:<4}  k={k0:4} | "
          f"{st['accuracy_pct']:7.4f}%  | {st['median_time_s']:8.4f}s")

[  64->2048 | k= 512] times=[24.89, 25.0, 25.05, 25.09, 25.15, 25.18, 25.35, 25.44, 25.56, 26.03]  →  median=25.1650s  | acc=94.2105%
✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/progressive_knn_medians.json
[  64->2048 | k=1024] times=[41.24, 41.36, 41.98, 41.99, 42.07, 42.2, 42.21, 42.58, 44.1, 47.19]  →  median=42.1350s  | acc=94.6154%
✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/progressive_knn_medians.json
[  64->3584 | k=   4] times=[4.56, 4.62, 4.64, 4.65, 4.65, 4.67, 4.67, 4.68, 4.78, 4.84]  →  median=4.6600s  | acc=84.8178%
✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/progressive_knn_medians.json
[  64->3584 | k=   8] times=[4.86, 4.89, 4.89, 4.91, 4.91, 4.93, 4.98, 4.99, 5.02, 5.05]  →  median=4.9200s  | acc=86.8016%
✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/progressive_knn_medians.json
[  64->3584 | k=  16] times=[5.34, 5.37, 5.39, 5.41, 5.41, 5.41, 5.53, 5.56, 5.61, 5.65] 

#### Continue Run 2

In [None]:
# Restore results saved by an earlier session, if the JSON file is present.
if os.path.exists(JSON_PATH):
    with open(JSON_PATH) as f:
        saved = json.load(f)
    # JSON keys are strings of comma-separated ints (optionally wrapped in
    # parentheses); convert each one back into a tuple of ints.
    prog_stats = {
        tuple(int(part) for part in k.strip("()").split(",")): v
        for k, v in saved.items()
    }
    print(f"↩️  loaded {len(prog_stats)} scenarios from {JSON_NAME}")

↩️  loaded 115 scenarios from progressive_knn_medians.json


In [None]:
# Dump the in-memory results dict to check what it currently contains.
print(prog_stats)

{(64, 128, 4): {'accuracy_pct': 83.7247, 'median_time_s': 4.04}, (64, 128, 8): {'accuracy_pct': 85.2227, 'median_time_s': 4.07}, (64, 128, 16): {'accuracy_pct': 86.4372, 'median_time_s': 4.29}, (64, 128, 32): {'accuracy_pct': 87.247, 'median_time_s': 4.75}, (64, 128, 64): {'accuracy_pct': 87.8947, 'median_time_s': 5.475}, (64, 128, 128): {'accuracy_pct': 88.502, 'median_time_s': 6.51}, (64, 128, 256): {'accuracy_pct': 88.583, 'median_time_s': 8.105}, (64, 128, 512): {'accuracy_pct': 88.7449, 'median_time_s': 9.345}, (64, 128, 1024): {'accuracy_pct': 88.7449, 'median_time_s': 12.995}, (64, 256, 4): {'accuracy_pct': 84.8178, 'median_time_s': 4.04}, (64, 256, 8): {'accuracy_pct': 86.6397, 'median_time_s': 4.23}, (64, 256, 16): {'accuracy_pct': 88.4211, 'median_time_s': 4.51}, (64, 256, 32): {'accuracy_pct': 89.7571, 'median_time_s': 5.14}, (64, 256, 64): {'accuracy_pct': 90.5263, 'median_time_s': 6.19}, (64, 256, 128): {'accuracy_pct': 91.5385, 'median_time_s': 7.83}, (64, 256, 256): {'ac

In [None]:
# Benchmark sweep: every (start dim, max dim, start k) combination that is not
# yet in prog_stats is run 10 times; the median time and the accuracy are
# stored and written to disk after each scenario.
for startD in dim_grid:
    for maxD in dim_grid:
        if maxD <= startD:            # only strictly ascending dimension ranges
            continue
        for k0 in start_k_ls:
            key = (startD, maxD, k0)
            if key in prog_stats:     # measured earlier or imported from a log
                continue

            scenario_tag = f"[{startD:4d}->{maxD:<4d} | k={k0:4d}]"

            times, accs = [], []
            for _ in range(10):
                res = progressive_knn_sklearn(
                    ds_embs=dataset_embs,
                    qs_embs=question_embs,
                    df_dataset=df_dataset,
                    contexts=contexts,
                    start_dim=startD,
                    start_k=k0,
                    max_dim=maxD,
                    step_factor=2,
                    step_k=2,
                    step_add=None,
                    verbose=False,
                )
                times.append(res["t_total_s"])
                accs.append(res["accuracy_pct"])
                gc.collect()

            times_sorted = [round(t, 4) for t in times]
            times_sorted.sort()
            median_time = round(float(np.median(times_sorted)), 4)

            # All repetitions are expected to report the same accuracy;
            # warn when they do not, but still record the first value.
            acc_unique = set(accs)
            if len(acc_unique) != 1:
                print(f"accuracy drift at {scenario_tag}")
            acc_final = round(accs[0], 4)

            # Same line layout as the log lines parsed from LOG_TEXT.
            print(f"{scenario_tag} "
                  f"times={times_sorted}  →  median={median_time:.4f}s  "
                  f"| acc={acc_final:.4f}%")

            prog_stats[key] = {
                "accuracy_pct": acc_final,
                "median_time_s": median_time,
            }
            save_json(prog_stats, JSON_NAME)   # rewrite the file after every scenario

# --------------------------------------------------------------
# Final summary of every recorded scenario, ordered by key.
print("\n=== MEDIAN TIMES (s) ===")
for scenario in sorted(prog_stats):
    sD, mD, k0 = scenario
    st = prog_stats[scenario]
    print(f"{sD:4}->{mD:<4}  k={k0:4} | "
          f"{st['accuracy_pct']:7.4f}%  | {st['median_time_s']:8.4f}s")

[ 256->1024 | k= 512] times=[36.38, 36.48, 36.84, 36.9, 36.92, 36.93, 37.03, 37.1, 37.12, 37.12]  →  median=36.9250s  | acc=94.4939%
✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/progressive_knn_medians.json
[ 256->1024 | k=1024] times=[48.37, 48.42, 48.65, 48.81, 48.87, 48.91, 49.04, 49.11, 49.22, 49.3]  →  median=48.8900s  | acc=94.4939%
✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/progressive_knn_medians.json
[ 256->2048 | k=   4] times=[9.79, 9.88, 9.93, 9.96, 9.96, 9.97, 9.99, 10.03, 10.05, 10.07]  →  median=9.9650s  | acc=94.1296%
✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/progressive_knn_medians.json
[ 256->2048 | k=   8] times=[10.26, 10.33, 10.39, 10.43, 10.44, 10.52, 10.57, 10.57, 10.58, 10.63]  →  median=10.4800s  | acc=94.7368%
✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/progressive_knn_medians.json
[ 256->2048 | k=  16] times=[11.49, 11.49, 11.51, 11.52, 11.59, 11.6, 11.62

#### Continue Run 3

In [None]:
# Restore results saved by an earlier session, if the JSON file is present.
if os.path.exists(JSON_PATH):
    with open(JSON_PATH) as f:
        saved = json.load(f)
    # JSON keys are strings of comma-separated ints (optionally wrapped in
    # parentheses); convert each one back into a tuple of ints.
    prog_stats = {
        tuple(int(part) for part in k.strip("()").split(",")): v
        for k, v in saved.items()
    }
    print(f"↩️  loaded {len(prog_stats)} scenarios from {JSON_NAME}")

↩️  loaded 160 scenarios from progressive_knn_medians.json


In [None]:
# Dump the in-memory results dict to check what it currently contains.
print(prog_stats)

{(64, 128, 4): {'accuracy_pct': 83.7247, 'median_time_s': 4.04}, (64, 128, 8): {'accuracy_pct': 85.2227, 'median_time_s': 4.07}, (64, 128, 16): {'accuracy_pct': 86.4372, 'median_time_s': 4.29}, (64, 128, 32): {'accuracy_pct': 87.247, 'median_time_s': 4.75}, (64, 128, 64): {'accuracy_pct': 87.8947, 'median_time_s': 5.475}, (64, 128, 128): {'accuracy_pct': 88.502, 'median_time_s': 6.51}, (64, 128, 256): {'accuracy_pct': 88.583, 'median_time_s': 8.105}, (64, 128, 512): {'accuracy_pct': 88.7449, 'median_time_s': 9.345}, (64, 128, 1024): {'accuracy_pct': 88.7449, 'median_time_s': 12.995}, (64, 256, 4): {'accuracy_pct': 84.8178, 'median_time_s': 4.04}, (64, 256, 8): {'accuracy_pct': 86.6397, 'median_time_s': 4.23}, (64, 256, 16): {'accuracy_pct': 88.4211, 'median_time_s': 4.51}, (64, 256, 32): {'accuracy_pct': 89.7571, 'median_time_s': 5.14}, (64, 256, 64): {'accuracy_pct': 90.5263, 'median_time_s': 6.19}, (64, 256, 128): {'accuracy_pct': 91.5385, 'median_time_s': 7.83}, (64, 256, 256): {'ac

In [None]:
# Benchmark sweep: every (start dim, max dim, start k) combination that is not
# yet in prog_stats is run 10 times; the median time and the accuracy are
# stored and written to disk after each scenario.
for startD in dim_grid:
    for maxD in dim_grid:
        if maxD <= startD:            # only strictly ascending dimension ranges
            continue
        for k0 in start_k_ls:
            key = (startD, maxD, k0)
            if key in prog_stats:     # measured earlier or imported from a log
                continue

            scenario_tag = f"[{startD:4d}->{maxD:<4d} | k={k0:4d}]"

            times, accs = [], []
            for _ in range(10):
                res = progressive_knn_sklearn(
                    ds_embs=dataset_embs,
                    qs_embs=question_embs,
                    df_dataset=df_dataset,
                    contexts=contexts,
                    start_dim=startD,
                    start_k=k0,
                    max_dim=maxD,
                    step_factor=2,
                    step_k=2,
                    step_add=None,
                    verbose=False,
                )
                times.append(res["t_total_s"])
                accs.append(res["accuracy_pct"])
                gc.collect()

            times_sorted = [round(t, 4) for t in times]
            times_sorted.sort()
            median_time = round(float(np.median(times_sorted)), 4)

            # All repetitions are expected to report the same accuracy;
            # warn when they do not, but still record the first value.
            acc_unique = set(accs)
            if len(acc_unique) != 1:
                print(f"accuracy drift at {scenario_tag}")
            acc_final = round(accs[0], 4)

            # Same line layout as the log lines parsed from LOG_TEXT.
            print(f"{scenario_tag} "
                  f"times={times_sorted}  →  median={median_time:.4f}s  "
                  f"| acc={acc_final:.4f}%")

            prog_stats[key] = {
                "accuracy_pct": acc_final,
                "median_time_s": median_time,
            }
            save_json(prog_stats, JSON_NAME)   # rewrite the file after every scenario

# --------------------------------------------------------------
# Final summary of every recorded scenario, ordered by key.
print("\n=== MEDIAN TIMES (s) ===")
for scenario in sorted(prog_stats):
    sD, mD, k0 = scenario
    st = prog_stats[scenario]
    print(f"{sD:4}->{mD:<4}  k={k0:4} | "
          f"{st['accuracy_pct']:7.4f}%  | {st['median_time_s']:8.4f}s")

[ 512->3584 | k= 512] times=[94.48, 94.74, 94.87, 94.95, 95.19, 95.62, 95.99, 95.99, 96.26, 96.43]  →  median=95.4050s  | acc=95.0202%
✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/progressive_knn_medians.json
[ 512->3584 | k=1024] times=[130.8, 131.5, 132.29, 132.91, 134.23, 136.68, 137.54, 137.92, 139.25, 139.4]  →  median=135.4550s  | acc=95.0202%
✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/progressive_knn_medians.json
[1024->2048 | k=   4] times=[30.72, 30.93, 30.97, 31.09, 31.2, 31.24, 31.31, 31.38, 31.46, 31.76]  →  median=31.2200s  | acc=94.8583%
✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/progressive_knn_medians.json
[1024->2048 | k=   8] times=[30.87, 31.24, 31.82, 31.84, 31.88, 31.92, 31.96, 32.01, 32.14, 32.3]  →  median=31.9000s  | acc=94.8583%
✔  saved → /content/drive/MyDrive/Colab Notebooks/CMPE 295 RAG System/progressive_knn_medians.json
[1024->2048 | k=  16] times=[32.19, 32.33, 32.47, 32.54, 3

In [None]:
# Number of scenarios currently held in prog_stats.
len(prog_stats)

189

# Plot

## Overall

In [None]:
# Start with two separate, empty result tables.
original_trunc_knn, progressive_trunc_knn = {}, {}

In [None]:
def load_json(json_path):
    """Load a results JSON file and turn its string keys back into int tuples.

    Each top-level key is expected to be a string of comma-separated integers,
    optionally wrapped in parentheses, e.g. "64,128,4", "(64, 128, 4)" or
    "(3584,)". Empty parts (such as the one after the trailing comma of a
    one-element tuple) are ignored instead of raising.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        dict mapping tuple-of-int keys to the values stored in the file.

    Raises:
        ValueError: If a key contains a non-empty part that is not an integer.
    """
    with open(json_path, encoding="utf-8") as f:
        raw = json.load(f)
    return {
        tuple(int(part) for part in key.strip("()").split(",") if part.strip()): value
        for key, value in raw.items()
    }

In [None]:
# File names of the two result tables.
original_trunc_knn_json = "truncate_first_knn_medians.json"
progress_trunc_knn_json = "progressive_knn_medians.json"

# Both files are looked up directly under drive_root.
original_trunc_knn_json_path = f"{drive_root}/{original_trunc_knn_json}"
progress_trunc_knn_json_path = f"{drive_root}/{progress_trunc_knn_json}"

# Baseline table first, then the progressive one.
original_trunc_knn = load_json(original_trunc_knn_json_path)
progressive_trunc_knn = load_json(progress_trunc_knn_json_path)

In [None]:
# Dump the baseline results dict to check what was loaded.
print(original_trunc_knn)

{(3584,): {'accuracy_pct': 95.0202, 'median_time_s': 99.3581}, (3072,): {'accuracy_pct': 94.9798, 'median_time_s': 87.5013}, (2048,): {'accuracy_pct': 94.8178, 'median_time_s': 57.4883}, (1024,): {'accuracy_pct': 94.4939, 'median_time_s': 29.6964}, (512,): {'accuracy_pct': 93.8057, 'median_time_s': 16.4421}, (504,): {'accuracy_pct': 93.8866, 'median_time_s': 15.7597}, (496,): {'accuracy_pct': 93.8057, 'median_time_s': 15.6233}, (488,): {'accuracy_pct': 93.7652, 'median_time_s': 15.3404}, (480,): {'accuracy_pct': 93.7247, 'median_time_s': 15.0536}, (472,): {'accuracy_pct': 93.7652, 'median_time_s': 14.7977}, (464,): {'accuracy_pct': 93.8057, 'median_time_s': 14.7158}, (456,): {'accuracy_pct': 93.7652, 'median_time_s': 14.5692}, (448,): {'accuracy_pct': 93.6842, 'median_time_s': 14.3969}, (440,): {'accuracy_pct': 93.5628, 'median_time_s': 14.2381}, (432,): {'accuracy_pct': 93.5223, 'median_time_s': 13.9428}, (424,): {'accuracy_pct': 93.6032, 'median_time_s': 14.0046}, (416,): {'accuracy_

In [None]:
# Dump the progressive results dict to check what was loaded.
print(progressive_trunc_knn)

{(64, 128, 4): {'accuracy_pct': 83.7247, 'median_time_s': 4.04}, (64, 128, 8): {'accuracy_pct': 85.2227, 'median_time_s': 4.07}, (64, 128, 16): {'accuracy_pct': 86.4372, 'median_time_s': 4.29}, (64, 128, 32): {'accuracy_pct': 87.247, 'median_time_s': 4.75}, (64, 128, 64): {'accuracy_pct': 87.8947, 'median_time_s': 5.475}, (64, 128, 128): {'accuracy_pct': 88.502, 'median_time_s': 6.51}, (64, 128, 256): {'accuracy_pct': 88.583, 'median_time_s': 8.105}, (64, 128, 512): {'accuracy_pct': 88.7449, 'median_time_s': 9.345}, (64, 128, 1024): {'accuracy_pct': 88.7449, 'median_time_s': 12.995}, (64, 256, 4): {'accuracy_pct': 84.8178, 'median_time_s': 4.04}, (64, 256, 8): {'accuracy_pct': 86.6397, 'median_time_s': 4.23}, (64, 256, 16): {'accuracy_pct': 88.4211, 'median_time_s': 4.51}, (64, 256, 32): {'accuracy_pct': 89.7571, 'median_time_s': 5.14}, (64, 256, 64): {'accuracy_pct': 90.5263, 'median_time_s': 6.19}, (64, 256, 128): {'accuracy_pct': 91.5385, 'median_time_s': 7.83}, (64, 256, 256): {'ac

In [None]:
# Overall figure: accuracy vs. median search time for the truncate-KNN
# baseline (one point per dimension) and every progressive-KNN scenario.

# ---------- 1)  flatten the baseline results ----------
# Rows are (dim, time, acc), sorted by dimension so the baseline line is
# drawn in order of increasing dimension.
orig_sorted = sorted(
    (dim, stats["median_time_s"], stats["accuracy_pct"])
    for (dim,), stats in original_trunc_knn.items()
)
orig_dims, orig_times, orig_accs = map(list, zip(*orig_sorted))

# ---------- 2)  flatten the progressive results ----------
prog_sdim  = [s_dim for s_dim, _, _ in progressive_trunc_knn]
prog_mdim  = [m_dim for _, m_dim, _ in progressive_trunc_knn]
prog_k0    = [k0 for _, _, k0 in progressive_trunc_knn]
prog_times = [stats["median_time_s"] for stats in progressive_trunc_knn.values()]
prog_accs  = [stats["accuracy_pct"] for stats in progressive_trunc_knn.values()]

# ---------- 3)  build traces ----------
trace_baseline = go.Scatter(
    x = orig_times,
    y = orig_accs,
    mode   = "lines+markers",
    line   = dict(color="royalblue"),
    marker = dict(symbol="circle", size=9),
    name   = "Truncate‑KNN baseline",
    text   = [f"Dim: {d}<br>Acc: {acc:.4f}%<br>Time: {t:.4f}s"
              for d, acc, t in zip(orig_dims, orig_accs, orig_times)],
    hoverinfo = "text"
)

trace_prog = go.Scatter(
    x = prog_times,
    y = prog_accs,
    mode   = "markers",
    marker = dict(color="crimson", symbol="square", size=8, opacity=0.85),
    name   = "Progressive KNN",
    text   = [
        (f"Start dim: {s}<br>End dim: {m}<br>Start k: {k0}"
         f"<br>Acc: {acc:.4f}%<br>Time: {t:.4f}s")
        for s, m, k0, acc, t in zip(
            prog_sdim, prog_mdim, prog_k0, prog_accs, prog_times)
    ],
    hoverinfo = "text"
)

# ---------- 4)  figure ----------
# The plotted times are the stored "median_time_s" values (a median over
# repeated whole runs), so the axis is labelled as a median search time and
# not as a per-query time; this matches the label used by make_figure.
# The y-axis starts one point below the lowest accuracy and ends at 100.
layout = go.Layout(
    title = "Accuracy vs Search Time  (Truncate‑KNN baseline  vs  Progressive KNN)",
    xaxis = dict(title="Median search time (s)"),
    yaxis = dict(title="Top‑1 accuracy (%)", range=[min(min(orig_accs),min(prog_accs))-1, 100]),
    height = 700,
    template = "plotly_white",
    hovermode = "closest"
)

fig = go.Figure(data=[trace_baseline, trace_prog], layout=layout)
fig.show()

## In detail

In [None]:
# Marker styles shared by all detail figures.
baseline_style    = dict(symbol="circle", size=9, color="royalblue")
prog_marker_style = dict(symbol="square", size=8, opacity=0.85, color="crimson")


# Baseline vectors: (dim, time, acc) rows sorted by dimension, then split
# into three parallel lists.
orig_sorted = sorted(
    (dim, stats["median_time_s"], stats["accuracy_pct"])
    for (dim,), stats in original_trunc_knn.items()
)
orig_dims, orig_times, orig_accs = map(list, zip(*orig_sorted))

# Hover text for each baseline point.
baseline_hover = [
    f"Dim {d}<br>Acc {a:.4f}%<br>{t:.4f}s"
    for d, a, t in zip(orig_dims, orig_accs, orig_times)
]

# Baseline trace reused by every figure that make_figure produces.
baseline_trace = go.Scatter(
    x=orig_times,
    y=orig_accs,
    mode="lines+markers",
    marker=baseline_style,
    line=dict(color=baseline_style["color"]),
    name="Truncate‑KNN baseline",
    hoverinfo="text",
    text=baseline_hover,
)

# helper that makes one figure
def make_figure(subset_keys, title_suffix):
    """Show the baseline trace together with a subset of progressive runs.

    Args:
        subset_keys: Iterable of (start dim, max dim, start k) keys that
            exist in progressive_trunc_knn.
        title_suffix: Text appended to the figure title.

    The figure is displayed with fig.show(); nothing is returned.
    """
    selected = [(key, progressive_trunc_knn[key]) for key in subset_keys]

    prog_times = [stats["median_time_s"] for _, stats in selected]
    prog_accs = [stats["accuracy_pct"] for _, stats in selected]
    prog_hover = [
        f"Start { s_dim } → End { m_dim }<br>k₀={k0}"
        f"<br>Acc {stats['accuracy_pct']:.4f}%"
        f"<br>{stats['median_time_s']:.4f}s"
        for (s_dim, m_dim, k0), stats in selected
    ]

    prog_trace = go.Scatter(
        x=prog_times,
        y=prog_accs,
        mode="markers",
        marker=prog_marker_style,
        name="Progressive KNN",
        text=prog_hover,
        hoverinfo="text",
    )

    # The y-axis starts one point below the lowest accuracy shown.
    y_floor = min(min(orig_accs), min(prog_accs)) - 1
    layout = go.Layout(
        title=f"Accuracy vs Time {title_suffix}",
        xaxis=dict(title="Median search time (s)"),
        yaxis=dict(title="Top‑1 accuracy (%)", range=[y_floor, 100]),
        template="plotly_white",
        height=650,
        hovermode="closest",
    )
    go.Figure(data=[baseline_trace, prog_trace], layout=layout).show()


### By Dim range

In [None]:
# build unique range list
# Group the scenario keys by (start dim, max dim), keeping dict order inside each group.
keys_by_range = {}
for key in progressive_trunc_knn:
    s_dim, m_dim, _ = key
    keys_by_range.setdefault((s_dim, m_dim), []).append(key)

# One figure per dimension range, in ascending order of range.
ranges = sorted(keys_by_range)
for s_dim, m_dim in ranges:
    keys = keys_by_range[(s_dim, m_dim)]
    make_figure(keys, f"(range { s_dim } → { m_dim })")

### By start_k

In [None]:
# Group the scenario keys by their start-k value, keeping dict order inside each group.
keys_by_k = {}
for key in progressive_trunc_knn:
    *_, k0 = key
    keys_by_k.setdefault(k0, []).append(key)

# One figure per start-k value, in ascending order.
k_vals = sorted(keys_by_k)
for k0 in k_vals:
    keys = keys_by_k[k0]
    make_figure(keys, f"(start‑k = {k0})")

# Print

In [None]:
# One line per baseline entry: key | accuracy | median time.
for k, v in original_trunc_knn.items():
    acc_pct, median_s = v['accuracy_pct'], v['median_time_s']
    print(f"{k} | {acc_pct:7.4f}%  | {median_s:8.4f}s")

(3584,) | 95.0202%  |  99.3581s
(3072,) | 94.9798%  |  87.5013s
(2048,) | 94.8178%  |  57.4883s
(1024,) | 94.4939%  |  29.6964s
(512,) | 93.8057%  |  16.4421s
(504,) | 93.8866%  |  15.7597s
(496,) | 93.8057%  |  15.6233s
(488,) | 93.7652%  |  15.3404s
(480,) | 93.7247%  |  15.0536s
(472,) | 93.7652%  |  14.7977s
(464,) | 93.8057%  |  14.7158s
(456,) | 93.7652%  |  14.5692s
(448,) | 93.6842%  |  14.3969s
(440,) | 93.5628%  |  14.2381s
(432,) | 93.5223%  |  13.9428s
(424,) | 93.6032%  |  14.0046s
(416,) | 93.6032%  |  13.6727s
(408,) | 93.5628%  |  13.5728s
(400,) | 93.6032%  |  13.3905s
(392,) | 93.5223%  |  13.0533s
(384,) | 93.4008%  |  12.8413s
(376,) | 93.2794%  |  12.4153s
(368,) | 93.2794%  |  12.0985s
(360,) | 93.2389%  |  11.9688s
(352,) | 93.1579%  |  11.7205s
(344,) | 92.9555%  |  11.5171s
(336,) | 92.8745%  |  11.3801s
(328,) | 92.8745%  |  11.1568s
(320,) | 92.7126%  |  10.8814s
(312,) | 92.5911%  |  10.7416s
(304,) | 92.3887%  |  10.5377s
(296,) | 92.5506%  |  10.4184s
(288

In [None]:
# One line per progressive scenario: key | accuracy | median time.
for k, v in progressive_trunc_knn.items():
    acc_pct, median_s = v['accuracy_pct'], v['median_time_s']
    print(f"{k} | {acc_pct:7.4f}%  | {median_s:8.4f}s")

(64, 128, 4) | 83.7247%  |   4.0400s
(64, 128, 8) | 85.2227%  |   4.0700s
(64, 128, 16) | 86.4372%  |   4.2900s
(64, 128, 32) | 87.2470%  |   4.7500s
(64, 128, 64) | 87.8947%  |   5.4750s
(64, 128, 128) | 88.5020%  |   6.5100s
(64, 128, 256) | 88.5830%  |   8.1050s
(64, 128, 512) | 88.7449%  |   9.3450s
(64, 128, 1024) | 88.7449%  |  12.9950s
(64, 256, 4) | 84.8178%  |   4.0400s
(64, 256, 8) | 86.6397%  |   4.2300s
(64, 256, 16) | 88.4211%  |   4.5100s
(64, 256, 32) | 89.7571%  |   5.1400s
(64, 256, 64) | 90.5263%  |   6.1900s
(64, 256, 128) | 91.5385%  |   7.8300s
(64, 256, 256) | 91.9838%  |  10.6550s
(64, 256, 512) | 92.4696%  |  14.5950s
(64, 256, 1024) | 92.6721%  |  19.9400s
(64, 512, 4) | 84.8178%  |   4.1100s
(64, 512, 8) | 86.8016%  |   4.3100s
(64, 512, 16) | 88.7854%  |   4.7200s
(64, 512, 32) | 90.1215%  |   5.6050s
(64, 512, 64) | 91.0526%  |   7.0500s
(64, 512, 128) | 92.1457%  |   9.3600s
(64, 512, 256) | 92.7126%  |  13.2100s
(64, 512, 512) | 93.3198%  |  19.0400s
(64, 

In [None]:
# Baseline results as an aligned table. The key of original_trunc_knn is the
# truncation dimension, so the first column is labelled "Dim". The value
# columns are right-aligned to the same widths as their headers (10 and 15).
print(f"{'Dim':<20} {'Accuracy %':>10} {'Median Time (s)':>15}")

for k, v in original_trunc_knn.items():
    print(f"{str(k):<20} {v['accuracy_pct']:>9.4f}% {v['median_time_s']:>14.4f}s")


K                    Accuracy % Median Time (s)
(3584,)              95.0202%     99.3581s
(3072,)              94.9798%     87.5013s
(2048,)              94.8178%     57.4883s
(1024,)              94.4939%     29.6964s
(512,)               93.8057%     16.4421s
(504,)               93.8866%     15.7597s
(496,)               93.8057%     15.6233s
(488,)               93.7652%     15.3404s
(480,)               93.7247%     15.0536s
(472,)               93.7652%     14.7977s
(464,)               93.8057%     14.7158s
(456,)               93.7652%     14.5692s
(448,)               93.6842%     14.3969s
(440,)               93.5628%     14.2381s
(432,)               93.5223%     13.9428s
(424,)               93.6032%     14.0046s
(416,)               93.6032%     13.6727s
(408,)               93.5628%     13.5728s
(400,)               93.6032%     13.3905s
(392,)               93.5223%     13.0533s
(384,)               93.4008%     12.8413s
(376,)               93.2794%     12.4153s
(368,)

In [None]:
# Progressive results as an aligned table. The first column is 22 wide so the
# 21-character header fits; the value columns are right-aligned to the same
# widths as their headers (10 and 15).
print(f"{'Start dim, Max dim, K':<22} {'Accuracy %':>10} {'Median Time (s)':>15}")

for k, v in progressive_trunc_knn.items():
    print(f"{str(k):<22} {v['accuracy_pct']:>9.4f}% {v['median_time_s']:>14.4f}s")


Start dim, Max dim, K Accuracy % Median Time (s)
(64, 128, 4)         83.7247%      4.0400s
(64, 128, 8)         85.2227%      4.0700s
(64, 128, 16)        86.4372%      4.2900s
(64, 128, 32)        87.2470%      4.7500s
(64, 128, 64)        87.8947%      5.4750s
(64, 128, 128)       88.5020%      6.5100s
(64, 128, 256)       88.5830%      8.1050s
(64, 128, 512)       88.7449%      9.3450s
(64, 128, 1024)      88.7449%     12.9950s
(64, 256, 4)         84.8178%      4.0400s
(64, 256, 8)         86.6397%      4.2300s
(64, 256, 16)        88.4211%      4.5100s
(64, 256, 32)        89.7571%      5.1400s
(64, 256, 64)        90.5263%      6.1900s
(64, 256, 128)       91.5385%      7.8300s
(64, 256, 256)       91.9838%     10.6550s
(64, 256, 512)       92.4696%     14.5950s
(64, 256, 1024)      92.6721%     19.9400s
(64, 512, 4)         84.8178%      4.1100s
(64, 512, 8)         86.8016%      4.3100s
(64, 512, 16)        88.7854%      4.7200s
(64, 512, 32)        90.1215%      5.6050s
(64, 

In [None]:
import csv

# Export the baseline results as CSV: one header row, then one row per entry.
with open(f"{drive_root}/static_truncate_knn.csv", "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["Key", "Accuracy (%)", "Median Time (s)"])
    writer.writerows(
        [
            ", ".join(str(part) for part in k),      # tuple key as "a, b, ..."
            f"{v['accuracy_pct']:.4f}",
            f"{v['median_time_s']:.4f}",
        ]
        for k, v in original_trunc_knn.items()
    )


In [None]:
# Export the progressive results as CSV, with the tuple key split into three
# separate columns.
with open(f"{drive_root}/progressive_truncate_knn.csv", "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["Starting Dim", "Max Dim", "Top K", "Accuracy (%)", "Median Time (s)"])

    for (starting_dim, max_dim, top_k), v in progressive_trunc_knn.items():
        accuracy = f"{v['accuracy_pct']:.4f}"
        # Named median_time_str rather than `time`, so the notebook-level
        # `time` module is not overwritten by a string.
        median_time_str = f"{v['median_time_s']:.4f}"

        writer.writerow([starting_dim, max_dim, top_k, accuracy, median_time_str])
