## setup

In [8]:
import redis
import json
from multiprocessing import Process, cpu_count
import json
import pandas as pd
import numpy as np

from typing import List

In [9]:
dataset_path = "../data/datasets/careful"
properties_path = "../data/outputs/careful.json"
metric = "pitch_histogram"

In [10]:
# FT.CREATE idx:table ON JSON PREFIX 1 cmp: SCHEMA $.sim AS sim NUMERIC $.row_file AS row_file TEXT $.col_file AS col_file TEXT $.metric AS metric TEXT

In [12]:
# redis setup
redis_url = "redis://localhost:6379"
r = redis.Redis(redis_url)


# load from fs
properties = {}
with open(properties_path, "r") as f:
    properties = json.load(f)

names = list(properties.keys())
names.sort()

num_processes = cpu_count()
rows_per_process = len(names) // num_processes  # type: ignore
extra_rows = len(names) % num_processes  # type: ignore
print(f"{len(names)} & {num_processes} -> {rows_per_process} + {extra_rows}")

3868 & 12 -> 322 + 4


## general tests

In [None]:
name1 = "20240121-70-06_0096-0104.mid"  # names[0]
name2 = "20240227-76-05_0128-0136.mid"  # names[-1]
print(f"{name1} {name2} {metric}")
r.json().get(f"cmp:{name1}:{name2}:{metric}")

In [None]:
def scan_keys(r, pattern):
    cursor = 0
    keys = []
    while True:
        cursor, new_keys = r.scan(cursor, match=pattern)
        keys.extend(new_keys)
        if cursor == 0:
            break
    return keys


# Pattern to match
pattern = "20231220-80-01_0000-0008.mid:*:pitch_histogram"

# Get all keys matching the pattern
# matching_keys = scan_keys(r, pattern)
# print(f"Keys matching pattern '{pattern}': {matching_keys}")

In [None]:
def process_json_keys(redis_conn):
    cursor = "0"
    while cursor != 0:
        cursor, keys = redis_conn.scan(cursor=cursor, count=1000)
        for key in keys:
            key_type = redis_conn.execute_command("TYPE", key)
            if key_type == b"ReJSON-RL":
                value = redis_conn.json().get(key)

                if value:
                    row_file, col_file, metric = str(key).split(":")

                    value["row_file"] = row_file[2:]
                    value["col_file"] = col_file
                    value["metric"] = metric

                    # print(f"Key: {key}, Data: {value}")
                    r.json().set(key, "$", value)

            else:
                # Ignore non-JSON objects
                continue
        print(f"finished section {cursor}")


# Call the function
process_json_keys(r)
print("DONE")

In [None]:
query = 'FT.SEARCH idx:sim "20231220-80-01_0000-0008.mid:*:pitch_histogram" SORTBY sim DESC LIMIT 0 10'
result = r.execute_command(query)

# Extract the key with the largest 'sim' value
if len(result) > 1:
    key_with_max_sim = result[1]
    print(f"Key with the largest 'sim' value: {key_with_max_sim}")
else:
    print("No matching keys found.")

## build df

In [14]:
def build_similarity_dataframe(
    redis_url: str,
    all_files: List[str],
    indices: slice,
    subr_idx=0,
    batch_size=1000,
    metric="pitch_histogram",
):
    r = redis.Redis(redis_url)
    df = pd.DataFrame(index=all_files[indices], columns=all_files, dtype=np.float16)

    print(
        f"[SUBR{subr_idx:02d}] populating a {indices.stop - indices.start} by {len(all_files)} df"
    )

    # process keys in batches
    for i in range(0, indices.stop - indices.start, batch_size):
        row_batch = all_files[indices][i : i + batch_size]

        # use Redis pipeline to batch process the keys
        with r.pipeline() as pipe:
            for row_file in row_batch:
                keys = [f"cmp:{row_file}:{col_file}:{metric}" for col_file in all_files]
                for key in keys:
                    pipe.json().get(key, "$.sim")
            results = pipe.execute()

        # parse the results and update the DataFrame
        for row_file in row_batch:
            keys = [f"cmp:{row_file}:{col_file}:{metric}" for col_file in all_files]
            for key, value in zip(keys, results):
                if value:
                    col_file = key.split(":")[1]
                    sim = round(
                        value["sim"], 5
                    )  # limit precision to 5 decimal places (fp16 is 3 i think)
                    if col_file in all_files:
                        df.at[row_file, col_file] = sim
                else:
                    print(f"[SUBR{subr_idx:02d}] ERROR: no value found at {key}")
    return i, df

In [15]:
processes = []
start_index = 0
for i in range(num_processes):
    end_index = start_index + rows_per_process + (1 if i < extra_rows else 0)
    index_range = slice(start_index, end_index)
    process = Process(
        target=build_similarity_dataframe,
        args=(redis_url, names, index_range, i),
    )
    processes.append(process)
    process.start()
    start_index = end_index

batch_dfs = []
for process in processes:
    subr_i, new_df = process.join()
    batch_dfs.append(new_df)
    print(f"[MAIN]   subroutine {subr_i} complete")

[SUBR00] populating an 323 by 3868 df
[SUBR01] populating an 323 by 3868 df[SUBR02] populating an 323 by 3868 df

[SUBR03] populating an 323 by 3868 df
[SUBR04] populating an 322 by 3868 df
[SUBR06] populating an 322 by 3868 df[SUBR05] populating an 322 by 3868 df

[SUBR07] populating an 322 by 3868 df
[SUBR08] populating an 322 by 3868 df
[SUBR09] populating an 322 by 3868 df
[SUBR10] populating an 322 by 3868 df
[SUBR11] populating an 322 by 3868 df


KeyboardInterrupt: 

Process Process-1:
Traceback (most recent call last):
  File "/home/finlay/disklavier/.venv/lib/python3.12/site-packages/redis/connection.py", line 276, in connect
    sock = self.retry.call_with_retry(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/finlay/disklavier/.venv/lib/python3.12/site-packages/redis/retry.py", line 46, in call_with_retry
    return do()
           ^^^^
  File "/home/finlay/disklavier/.venv/lib/python3.12/site-packages/redis/connection.py", line 277, in <lambda>
    lambda: self._connect(), lambda error: self.disconnect(error)
            ^^^^^^^^^^^^^^^
  File "/home/finlay/disklavier/.venv/lib/python3.12/site-packages/redis/connection.py", line 607, in _connect
    for res in socket.getaddrinfo(
               ^^^^^^^^^^^^^^^^^^^
  File "/home/finlay/miniconda/envs/py312/lib/python3.12/socket.py", line 963, in getaddrinfo
    for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^