In [1]:
import json
import matplotlib.pyplot as plt
import prettytable

def load_data(file_name):
    """Parse quantization-recall benchmark results from ``go test -json`` output.

    The file is expected to be produced by::

        go test -bench=BenchmarkQuantizationRecall -run=^$ ./adapters/repos/db/vector/compressionhelpers/ -count=1 -benchtime=1x -json

    Each line is a JSON event.  Only events whose ``Output`` contains a
    tab-separated benchmark result line with at least eight fields and a
    ``...|<dataset>|<description>...`` benchmark name are used; every other
    line (including blank ones) is ignored.

    Args:
        file_name: path to the JSON-lines file.

    Returns:
        A list of dicts, one per result line in file order, with keys
        ``dataset``, ``algorithm`` (first two characters of the description),
        ``bits`` (int), ``description``, ``recall100@100`` and
        ``recall100@500`` (floats).
    """
    rows = []
    with open(file_name, "r", encoding="utf-8") as file:
        for line in file:
            if not line.strip():
                continue  # blank lines would make json.loads raise
            json_line = json.loads(line)
            if "Output" not in json_line or "\t" not in json_line["Output"]:
                continue
            tabs = json_line["Output"].split("\t")
            # Fields up to index 7 are read below, so eight are required.
            if len(tabs) <= 7:
                continue
            name_parts = tabs[0].split("|")
            if len(name_parts) < 3:
                continue  # not a "<name>|<dataset>|<description>..." result line
            dataset, description = name_parts[1:3]
            rows.append({
                "dataset": dataset,
                "algorithm": description[:2],
                # The fixed-length slices strip each metric's unit suffix;
                # the last field is one char longer, presumably because it
                # carries the line's trailing newline. TODO confirm units.
                "bits": int(float(tabs[3][:-4])),
                "description": description,
                "recall100@100": float(tabs[6][:-10]),
                "recall100@500": float(tabs[7][:-11]),
            })
    return rows

def extract_column(dataset, property, rows):
    """Collect ``row[property]`` for every row whose ``"dataset"`` equals *dataset*.

    Order follows *rows*.  A matching row without the *property* key, or any
    row without a ``"dataset"`` key, raises ``KeyError``.
    """
    # ``property`` shadows the builtin but is kept: it is part of the public signature.
    return [row[property] for row in rows if row["dataset"] == dataset]

def plot(file_name, dataset):
    """Scatter-plot recall100@100 against bits/dimension for one dataset.

    Points are coloured by algorithm (the first two characters of the
    benchmark description) with a legend mapping colour to algorithm, and
    each point is annotated with its full description so that different
    configurations of the same algorithm can be told apart.

    Args:
        file_name: ``go test -json`` benchmark output, parsed with ``load_data``.
        dataset: name of the dataset to plot, e.g. ``"sift-128-euclidean"``.
    """
    rows = load_data(file_name)
    bits = extract_column(dataset, "bits", rows)
    algorithm = extract_column(dataset, "algorithm", rows)
    description = extract_column(dataset, "description", rows)
    rec100at100 = extract_column(dataset, "recall100@100", rows)
    colormap = {"BQ": "green", "PQ": "blue", "SQ": "purple", "RQ": "orange"}

    _, ax = plt.subplots()
    # One scatter call per algorithm so each gets its own labelled legend
    # entry; dict.fromkeys keeps the order of first appearance.
    for algo in dict.fromkeys(algorithm):
        idx = [i for i, a in enumerate(algorithm) if a == algo]
        ax.scatter(
            [bits[i] for i in idx],
            [rec100at100[i] for i in idx],
            c=colormap.get(algo, "gray"),  # unknown algorithms still plot
            label=algo,
        )
    ax.grid(True)
    for i, desc in enumerate(description):
        # Small x offset keeps the label from covering its marker.
        ax.annotate(desc, (bits[i] + 0.2, rec100at100[i]))
    ax.set_ylabel("recall100@100")
    ax.set_xlabel("bits/dimension")
    if algorithm:
        ax.legend()  # skipped when nothing was plotted (no labelled artists)

    ax.set_title(dataset)
    plt.show()

def _format_recalls(rows, dataset):
    """Group "recall100@100 (recall100@500)" strings by (description, bits) for one dataset.

    Values for the same key stay in file order so repeated configurations
    (e.g. several runs of one benchmark) can be matched up positionally.
    """
    grouped = {}
    for r in rows:
        if r["dataset"] == dataset:
            key = (r["description"], r["bits"])
            grouped.setdefault(key, []).append(
                f"{r['recall100@100']:.3f} ({r['recall100@500']:.3f})"
            )
    return grouped

# Algorithm / Dataset -> recall
def ascii_table(file_name, datasets):
    """Print a table of recall per quantization setting (rows) and dataset (columns).

    Row order and the Algorithm/Bits columns come from ``datasets[0]``.  Each
    dataset cell shows ``recall100@100 (recall100@500)`` for the matching
    (description, bits) configuration, or ``-`` if that dataset has no such
    result.  A divider separates consecutive rows with different bit budgets.

    Args:
        file_name: ``go test -json`` benchmark output, parsed with ``load_data``.
        datasets: non-empty sequence of dataset names; raises ``IndexError``
            if empty.
    """
    rows = load_data(file_name)

    table = prettytable.PrettyTable()
    table.field_names = ["Algorithm", "Bits", *datasets]

    reference = datasets[0]
    descriptions = extract_column(reference, "description", rows)
    bits_column = extract_column(reference, "bits", rows)
    # Look values up by configuration rather than by list index, so datasets
    # with missing or differently ordered results are not misaligned.
    recalls = [_format_recalls(rows, ds) for ds in datasets]

    occurrences = {}
    prev_bits = None
    for desc, bits in zip(descriptions, bits_column):
        key = (desc, bits)
        nth = occurrences.get(key, 0)
        occurrences[key] = nth + 1
        cells = []
        for grouped in recalls:
            values = grouped.get(key, [])
            cells.append(values[nth] if nth < len(values) else "-")
        # No divider before the first row, only between bit-budget groups.
        if prev_bits is not None and bits != prev_bits:
            table.add_divider()
        prev_bits = bits
        table.add_row([desc, bits, *cells])

    print(table)

# Print the recall comparison table for both benchmarked datasets.
datasets = [
    "sift-128-euclidean",
    "glove-200-angular",
]
ascii_table(file_name="sift_gist_100k_q250.txt", datasets=datasets)



+-----------+------+--------------------+-------------------+
| Algorithm | Bits | sift-128-euclidean | glove-200-angular |
+-----------+------+--------------------+-------------------+
|     BQ    |  1   |   0.001 (0.004)    |   0.164 (0.340)   |
|  PQ(8,8)  |  1   |   0.715 (0.992)    |   0.370 (0.668)   |
|  PQ(4,4)  |  1   |   0.599 (0.953)    |   0.322 (0.598)   |
|     RQ    |  1   |   0.379 (0.796)    |   0.383 (0.728)   |
+-----------+------+--------------------+-------------------+
|  PQ(8,4)  |  2   |   0.819 (1.000)    |   0.644 (0.942)   |
|  PQ(4,2)  |  2   |   0.730 (0.995)    |   0.603 (0.912)   |
|     RQ    |  2   |   0.640 (0.979)    |   0.656 (0.968)   |
+-----------+------+--------------------+-------------------+
|  PQ(8,2)  |  4   |   0.918 (1.000)    |   0.872 (0.999)   |
|  PQ(4,1)  |  4   |   0.868 (1.000)    |   0.826 (0.993)   |
|     RQ    |  4   |   0.895 (1.000)    |   0.898 (1.000)   |
+-----------+------+--------------------+-------------------+
|     SQ