In [1]:
import os
import h5py
import numpy as np
import pandas as pd
from scipy.spatial.distance import cosine
from itertools import product

In [2]:
rng = np.random.default_rng(0)

In [4]:
# verify hdf5 file
p_h5 = "files/clamp_embeddings.h5"
with h5py.File(p_h5, "r") as hdf:
    print(f"{len(hdf.keys())} rows in file")
    # Iterate over all items in the HDF5 file
    nk = 0
    for key in hdf.keys():
        # Get the dataset associated with the key
        dataset = hdf[key]
        # Print the key and its corresponding value
        print(f"Key: '{key}', Value: {dataset}")
        break
embedding_dataset = h5py.File(p_h5, "r")
embedding_dataset

545952 rows in file
Key: '20231220-080-01_0000-0005_t00s00', Value: <HDF5 dataset "20231220-080-01_0000-0005_t00s00": shape (768,), type "<f8">


<HDF5 file "clamp_embeddings.h5" (mode r)>

In [6]:
# load prompt dataset
p_test = "../data/datasets/test/dataset samples"
transformations = [f"t{t:02d}s{s:02d}" for t, s in product(range(12), range(8))]
q_files = [f"{f[:-4]}_{rng.choice(transformations)}" for f in os.listdir(p_test)]
q_files[:3]

['20240122-055-03_0052-0060_t10s01',
 '20231220-080-07_0071-0077_t07s05',
 '20240123-070-02_0150-0157_t06s01']

In [7]:
# load prompt embeddings
q_embeddings = [embedding_dataset[key] for key in q_files]
q_embeddings[0]

<HDF5 dataset "20240122-055-03_0052-0060_t10s01": shape (768,), type "<f8">

In [19]:
embedding_dataset.values()

ValuesViewHDF5(<HDF5 file "clamp_embeddings.h5" (mode r)>)

In [23]:
# find best and worst n matches for prompts
n_matches = 3
indices = []  # To store the indices of the best and worst matches
for q_emb in q_embeddings:
    similarities = []
    for idx, dataset in enumerate(embedding_dataset.values()):
        sim = 1 - cosine(q_emb, dataset)
        similarities.append((sim, idx))

    similarities.sort(key=lambda x: x[0], reverse=True)  # Highest first
    best_indices = [idx for _, idx in similarities[:n_matches]]  # Top 3
    worst_indices = [idx for _, idx in similarities[-n_matches:]]  # Bottom 3

    indices.append((best_indices, worst_indices))

0 <HDF5 dataset "20231220-080-01_0000-0005_t00s00": shape (768,), type "<f8">
1 <HDF5 dataset "20231220-080-01_0000-0005_t00s01": shape (768,), type "<f8">
2 <HDF5 dataset "20231220-080-01_0000-0005_t00s02": shape (768,), type "<f8">
3 <HDF5 dataset "20231220-080-01_0000-0005_t00s03": shape (768,), type "<f8">
4 <HDF5 dataset "20231220-080-01_0000-0005_t00s04": shape (768,), type "<f8">
5 <HDF5 dataset "20231220-080-01_0000-0005_t00s05": shape (768,), type "<f8">
6 <HDF5 dataset "20231220-080-01_0000-0005_t00s06": shape (768,), type "<f8">
7 <HDF5 dataset "20231220-080-01_0000-0005_t00s07": shape (768,), type "<f8">
8 <HDF5 dataset "20231220-080-01_0000-0005_t01s00": shape (768,), type "<f8">
9 <HDF5 dataset "20231220-080-01_0000-0005_t01s01": shape (768,), type "<f8">
10 <HDF5 dataset "20231220-080-01_0000-0005_t01s02": shape (768,), type "<f8">
11 <HDF5 dataset "20231220-080-01_0000-0005_t01s03": shape (768,), type "<f8">
12 <HDF5 dataset "20231220-080-01_0000-0005_t01s04": shape (76

KeyboardInterrupt: 

In [None]:
embedding_dataset.close()