In [None]:
from collections import defaultdict
import os
import csv

# Directory holding one CSV per benchmark run, named "<algm>-query-<time>.csv".
csv_dir = "../log/ms-marco/vectors-colbert/k10_s1K_v137K/maxsim"
# csv_dir = "../log/ms-marco/vectors-colbert/k10_s1K_v63K"


def _parse_number(text):
    """Parse one CSV cell as int when possible, falling back to float (e.g. '12.5')."""
    try:
        return int(text)
    except ValueError:
        return float(text)


def _load_run_csv(filepath):
    """Read one run CSV into ``{column_name: [values...]}``.

    Returns None when the file is empty or has a header but no data rows.
    """
    # "utf-8-sig" strips a leading byte-order mark if present, so the first
    # header name matches exactly; it behaves like "utf-8" otherwise.
    with open(filepath, newline="", encoding="utf-8-sig") as f:
        rows = list(csv.reader(f))

    if len(rows) < 2:
        return None

    header, data_rows = rows[0], rows[1:]
    records = {h: [] for h in header}
    for row in data_rows:
        # zip() stops at the shorter of header/row, so blank rows add nothing.
        for h, v in zip(header, row):
            records[h].append(_parse_number(v))
    return records


# algm -> list of (time string, records) tuples, one per run CSV.
data_dict = defaultdict(list)

# sorted() makes the order of runs with equal time strings deterministic.
for csv_name in sorted(os.listdir(csv_dir)):
    if not csv_name.endswith(".csv"):
        continue

    # "<algm>-query-<time>.csv" -> ("<algm>", "<time>"); time is "" if the marker is absent.
    algm, _, time = csv_name[:-4].partition("-query-")
    records = _load_run_csv(os.path.join(csv_dir, csv_name))
    if records is None:
        continue

    data_dict[algm].append((time, records))

# Sort each algorithm's runs by time string, descending, so index 0 is the
# lexicographically largest (presumably the newest run, assuming a sortable
# timestamp format -- TODO confirm against the benchmark writer).
for runs in data_dict.values():
    runs.sort(key=lambda run: run[0], reverse=True)

# Compact summary (algorithm -> run times) instead of dumping every column.
{algm: [time for time, _ in runs] for algm, runs in data_dict.items()}

In [None]:
import matplotlib.pyplot as plt


# Per-algorithm plot style, keyed by the algorithm name parsed from the CSV
# filenames. A value is either one style dict or a list of style dicts (one
# plotted line per entry). Each style needs "color", "marker" and "label", and
# may add an optional "time" key selecting a specific run; without it the
# plotting loops use run index 0.
plot_styles = dict(
    brute_force={"color": "black", "marker": "^", "label": "Brute Force"},
    hnswlib={"color": "orange", "marker": "s", "label": "HNSWLib"},
    ivfpq={"color": "green", "marker": "D", "label": "IVFPQ"},
    hnsw={"color": "red", "marker": "o", "label": "HNSW"},
    set_hnsw={"color": "blue", "marker": ">", "label": "Set HNSW"},
    muvera={"color": "brown", "marker": "x", "label": "Muvera"},
    prune_hnsw=[
        {"color": "cyan", "marker": "p", "label": "Prune HNSW"},
    ],
)


def get_records(algm, index, data=None):
    """Return the parsed column dict for one run of ``algm``.

    Parameters
    ----------
    algm : str
        Algorithm key (the part of the CSV filename before ``-query-``).
    index : int or str
        An ``int`` selects a run by position in the algorithm's run list
        (negative positions count from the end). A ``str`` selects the run
        whose time string equals it.
    data : mapping, optional
        ``{algm: [(time, records), ...]}``. Defaults to the notebook-level
        ``data_dict``.

    Returns
    -------
    dict or None
        ``{column_name: [values...]}`` for the selected run. Returns None when
        the algorithm is unknown, the position is out of range, the time
        string is not found, or ``index`` is neither int nor str.
    """
    if data is None:
        data = data_dict

    # .get() avoids inserting a key into a defaultdict on lookup.
    runs = data.get(algm)
    if not runs:
        return None

    if isinstance(index, int):
        # Treat an out-of-range position like a missing run, matching the other
        # "not found" cases that callers already skip via `is None`.
        if -len(runs) <= index < len(runs):
            return runs[index][1]
        return None
    if isinstance(index, str):
        return next((records for time, records in runs if time == index), None)
    return None


def extract_xy(records):
    """Compute per-row (recall, QPS) points from one run's columns.

    Parameters
    ----------
    records : dict
        Column dict with at least "hit", "total", "q_num" and "time" lists.

    Returns
    -------
    (list, list)
        recall = hit / total per row, or 0 when total is 0 (same convention
        as extract_dist_metrics).
        qps = q_num / (time / 1e6) per row, which assumes ``time`` is in
        microseconds (TODO confirm against the benchmark writer). It is NaN
        when time is 0, so matplotlib skips the point instead of the whole
        cell raising ZeroDivisionError.
    """
    recall = [h / t if t else 0 for h, t in zip(records["hit"], records["total"])]
    qps = [q / (t / 1e6) if t else float("nan") for q, t in zip(records["q_num"], records["time"])]
    return recall, qps


# Recall vs QPS, one line per configured style. This uses the explicit
# Figure/Axes interface, matching the distance-computation cell.
fig, ax = plt.subplots(figsize=(7, 5))

for algm, styles in plot_styles.items():
    # A style entry may be a single dict or a list of dicts (one line each).
    for style in styles if isinstance(styles, list) else [styles]:
        # Optional "time" selects a specific run; default 0 is the first run
        # in the loader's descending sort.
        records = get_records(algm, style.get("time", 0))
        if records is None:
            continue
        recall, qps = extract_xy(records)
        ax.plot(recall, qps, marker=style["marker"], color=style["color"], label=style["label"], linewidth=1.5)

ax.set_xlabel("Recall")
ax.set_ylabel("QPS (queries per second)")
ax.set_yscale("log")
ax.set_title("Recall vs QPS (log scale)")
# Only draw a legend when something was plotted (avoids the "No artists with
# labels" warning when no runs were loaded).
if ax.get_legend_handles_labels()[0]:
    ax.legend()
ax.grid(True, linestyle="--", alpha=0.6)
fig.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt


def extract_dist_metrics(records):
    """Return recall plus per-query averages of the distance-computation counters.

    Parameters
    ----------
    records : dict
        Column dict for one run.

    Returns
    -------
    tuple of four lists, or None
        ``(recall, hnsw_avg, rerank_avg, total_avg)``, aligned row by row.
        recall is hit / total; each average divides the counter column by
        q_num. A zero denominator yields 0 instead of raising. Returns None
        when any needed column is missing.
    """
    needed = ("q_num", "hnsw_dist_comps", "rerank_dist_comps", "tot_dist_comps", "hit", "total")
    if any(column not in records for column in needed):
        return None

    queries = records["q_num"]

    recall = []
    for hits, total in zip(records["hit"], records["total"]):
        recall.append(hits / total if total else 0)

    def per_query_mean(column):
        # Average a counter column over the number of queries in each row.
        return [count / n if n else 0 for count, n in zip(records[column], queries)]

    hnsw_avg = per_query_mean("hnsw_dist_comps")
    rerank_avg = per_query_mean("rerank_dist_comps")
    total_avg = per_query_mean("tot_dist_comps")
    return recall, hnsw_avg, rerank_avg, total_avg


# Averaged distance-computation series to draw: any of "hnsw", "rerank", "total".
# Each series has its own line style, so several can share one axes.
dist_series = ("rerank",)
series_linestyle = {"hnsw": "-", "rerank": "--", "total": ":"}

fig, ax = plt.subplots(figsize=(7, 5))

for algm, styles in plot_styles.items():
    # A style entry may be a single dict or a list of dicts (one line each).
    for style in styles if isinstance(styles, list) else [styles]:
        records = get_records(algm, style.get("time", 0))
        if records is None:
            continue
        extracted = extract_dist_metrics(records)
        if extracted is None:
            # Run lacks the distance-computation columns; nothing to plot.
            continue

        recall, hnsw_d, rerank_d, tot_d = extracted
        by_name = {"hnsw": hnsw_d, "rerank": rerank_d, "total": tot_d}
        for name in dist_series:
            # Every line needs a label, otherwise ax.legend() has nothing to show.
            label = style["label"] if len(dist_series) == 1 else f"{style['label']} ({name})"
            ax.plot(
                recall,
                by_name[name],
                marker=style["marker"],
                color=style["color"],
                linestyle=series_linestyle[name],
                label=label,
            )

ax.set_xlabel("Recall")
ax.set_ylabel("Avg Distance Computations per Query")
ax.set_yscale("log")
if ax.get_legend_handles_labels()[0]:
    ax.legend(fontsize=9)
ax.grid(True, linestyle="--", alpha=0.5)
ax.set_title("Distance Computations vs Recall (log scale)")
fig.tight_layout()
plt.show()