In [1]:
import os
import h5py
import torch
import numpy as np
import pandas as pd
from scipy.spatial.distance import cosine
from itertools import product
from rich.progress import track

In [2]:
rng = np.random.default_rng(0)
BUILD_DATASET = False

In [15]:
p_h5 = "data/clamp_embeddings.h5"
if BUILD_DATASET:
    import redis
    r = redis.Redis(host='localhost', port=6379, db=0, decode_responses=True)
    n_rows = 0
    keys: list[str] = r.keys('files:*') # type: ignore
    keys.sort()
    embeddings = torch.empty((len(keys), 768), dtype=torch.float32, device='cuda')
    for i, key in track(enumerate(keys), "downloading embeddings"): # type: ignore
        embeddings[i] = torch.Tensor(r.json().get(key, "$.clamp")[0]) # type: ignore
    print(embeddings.shape)
else:
    # verify hdf5 file
    with h5py.File(p_h5, "r") as hdf:
        print(f"{len(hdf.keys())} rows in file")
        embeddings = pd.DataFrame(index=list(hdf.keys()), columns=["embedding"])
        # Iterate over all items in the HDF5 file
        nk = 0
        for key in hdf.keys():
            # Get the dataset associated with the key
            dataset = hdf[key]
            # Print the key and its corresponding value
            if key in embeddings.index:  # Check if the key exists
                embeddings.loc[key, "embedding"] = hdf[key][:]
            else:
                print(f"Key '{key}' not found in embeddings.")
        print(f"loaded embeddings: {embeddings.shape}")

545952 rows in file
loaded embeddings: (545952, 1)


In [17]:
print(f"Size of embeddings in MB: {embeddings.memory_usage(deep=True).sum() / (1024 ** 2):.2f} MB")


Size of embeddings in MB: 1720.24 MB


In [41]:
# os.remove(p_h5)
with h5py.File(p_h5, 'w') as hdf:
    for i, (k, e) in enumerate(zip(keys, embeddings)):
        hdf.create_dataset(k.split(':')[-1], data=e.cpu().numpy())
        keys[i] = keys[i].split(':')[-1]
print(f"wrote {len(keys)} rows to '{p_h5}'\n", keys[:3])

wrote 545952 rows to 'data/clamp_embeddings.h5' ['20231220-080-01_0000-0005_t00s00', '20231220-080-01_0000-0005_t00s01', '20231220-080-01_0000-0005_t00s02']


In [42]:
def get_embedding(filename: str) -> torch.Tensor:
	return embeddings[keys.index(filename)]

In [43]:
# load prompt dataset
p_test = "../data/datasets/test/dataset samples"
transformations = [f"t{t:02d}s{s:02d}" for t, s in product(range(12), range(8))]
q_files = [f"{f[:-4]}_{rng.choice(transformations)}" for f in os.listdir(p_test)]
q_files[:3]

['20240511-088-02_0119-0125_t11s07',
 '20240305-050-04_0076-0086_t02s03',
 '20231220-080-04_0029-0035_t11s02']

In [48]:
# load prompt embeddings
q_embeddings = [get_embedding(key) for key in q_files]
q_embeddings[0].shape

torch.Size([768])

In [77]:
# find best and worst n matches for prompts
n_matches = 5
matches = {}
for i, query in enumerate(q_embeddings):
    similarities = torch.matmul(embeddings, query) / (torch.norm(embeddings, dim=1) * torch.norm(query))
    sim_cpu = similarities.cpu().numpy()
    matches[q_files[i]] = {}

    # no regard for different segments
    matches[q_files[i]]['simple'] = []
    for j in np.argsort(-sim_cpu)[:n_matches]: # indices of best matches
        matches[q_files[i]]['simple'].append([keys[j], float(sim_cpu[j])])
    for j in np.argsort(-sim_cpu)[-n_matches:]: # indices of worst matches
        matches[q_files[i]]['simple'].append([keys[j], float(sim_cpu[j])])

    # force different segments and unique each time
    matches[q_files[i]]['cplx'] = []
    q_track = q_files[i].split('_')[0]
    n_found = 0
    for k in np.argsort(-sim_cpu):
        match_track = keys[k].split('_')[0]
        added_matches = [m[0].split('_')[0] for m in matches[q_files[i]]['cplx']]
        if match_track == q_track or match_track in added_matches:
            continue
        matches[q_files[i]]['cplx'].append([keys[k], float(sim_cpu[k])])
        n_found += 1
        if n_found >= n_matches:
            break
    n_found = 0
    for k in reversed(np.argsort(-sim_cpu)):
        match_track = keys[k].split('_')[0]
        added_matches = [m[0].split('_')[0] for m in matches[q_files[i]]['cplx']]
        if match_track == q_track or match_track in added_matches:
            continue
        matches[q_files[i]]['cplx'].append([keys[k], float(sim_cpu[k])])
        n_found += 1
        if n_found >= n_matches:
            break

    # reverse order of worst complex matches due to search direction
    matches[q_files[i]]['cplx'][-n_matches:] = reversed(matches[q_files[i]]['cplx'][-n_matches:])
    
print(q_files[0])
matches[q_files[0]]

20240511-088-02_0119-0125_t11s07


{'simple': [['20240511-088-02_0119-0125_t11s07', 1.0000001192092896],
  ['20240511-088-02_0119-0125_t08s07', 0.9173212647438049],
  ['20240511-088-03_0021-0027_t06s06', 0.9104568958282471],
  ['20240511-088-03_0021-0027_t03s05', 0.9012600779533386],
  ['20240511-088-03_0021-0027_t09s04', 0.8920480012893677],
  ['20231227-080-03_0233-0239_t11s03', 0.31363505125045776],
  ['20240123-070-07_0425-0431_t06s03', 0.31282979249954224],
  ['20240123-070-07_0438-0445_t06s00', 0.3121066391468048],
  ['20240123-070-07_0438-0445_t00s04', 0.30916541814804077],
  ['20240123-070-07_0438-0445_t08s05', 0.2884775698184967]],
 'cplx': [['20240511-088-03_0021-0027_t06s06', 0.9104568958282471],
  ['20240124-064-02_0449-0457_t10s00', 0.8811522126197815],
  ['20240429-068-03_0070-0077_t10s04', 0.8779870271682739],
  ['20240312-080-05_0035-0041_t00s05', 0.8743062019348145],
  ['20240123-070-03_0774-0781_t08s03', 0.8720983266830444],
  ['20240312-080-03_0005-0011_t06s00', 0.3329693675041199],
  ['20240227-076-0

In [78]:
import json

with open('data/matches.json', 'w') as f:
    json.dump(matches, f)

In [71]:
i = q_files.index("20240511-088-02_0119-0125_t11s07")
query = q_embeddings[i]
similarities = torch.matmul(embeddings, query) / (torch.norm(embeddings, dim=1) * torch.norm(query))
sim_cpu = similarities.cpu().numpy()
[keys[k]  for k in np.argsort(-sim_cpu)[:20]]

['20240511-088-02_0119-0125_t11s07',
 '20240511-088-02_0119-0125_t08s07',
 '20240511-088-03_0021-0027_t06s06',
 '20240511-088-03_0021-0027_t03s05',
 '20240511-088-03_0021-0027_t09s04',
 '20240511-088-03_0065-0070_t01s06',
 '20240511-088-02_0076-0081_t09s07',
 '20240124-064-02_0449-0457_t10s00',
 '20240429-068-03_0070-0077_t10s04',
 '20240511-088-03_0000-0005_t00s05',
 '20240511-088-02_0119-0125_t06s01',
 '20240429-068-03_0197-0204_t07s01',
 '20240124-064-02_0607-0614_t09s07',
 '20240124-064-02_0637-0644_t08s03',
 '20240511-088-02_0119-0125_t08s01',
 '20240312-080-05_0035-0041_t00s05',
 '20240312-080-05_0041-0047_t01s05',
 '20240123-070-03_0774-0781_t08s03',
 '20240124-064-02_0569-0577_t07s04',
 '20240124-064-02_0622-0629_t06s04']

In [76]:
with open('data/matches.json', 'r') as f:
    json_matches = json.load(f)
index = 0
for i, (q, ms) in enumerate(json_matches.items()):
    if i == index:
        print(f"{i}\t{q}")
        for mode, matches in ms.items():
            print(f"\t{mode}")
            for f, s in matches:
                print(f"\t\t{f}\t{s}")

0	20240511-088-02_0119-0125_t11s07
	simple
		20240511-088-02_0119-0125_t11s07	1.0000001192092896
		20240511-088-02_0119-0125_t08s07	0.9173212647438049
		20240511-088-03_0021-0027_t06s06	0.9104568958282471
		20240123-070-07_0438-0445_t06s00	0.3121066391468048
		20240123-070-07_0438-0445_t00s04	0.30916541814804077
		20240123-070-07_0438-0445_t08s05	0.2884775698184967
	cplx
		20240511-088-03_0021-0027_t06s06	0.9104568958282471
		20240511-088-03_0065-0070_t01s06	0.8881725072860718
		20240511-088-02_0076-0081_t09s07	0.8876845836639404
		20231227-080-03_0233-0239_t11s03	0.31363505125045776
		20240123-070-07_0425-0431_t06s03	0.31282979249954224
		20240123-070-07_0438-0445_t08s05	0.2884775698184967
