# In [None]:
'''
The approach leverages the strengths of both FAISS and fastdtw to efficiently handle the search and refinement process:

FAISS: Efficiently retrieves a large initial candidate pool based on the last step of the query sequence. This is fast and handles high-dimensional data well.
fastdtw: Iteratively refines the candidate pool by considering increasingly long suffixes of the query, so that the more expensive longer DTW comparisons are only made against a shrinking set of candidates rather than all of them.
This combination keeps the search efficient while still producing good matches. Note that it is a heuristic: FastDTW only approximates exact DTW, and pruning on short suffixes can discard the sequence that would match best over the full length.

Summary of the Approach
Initial Candidate Retrieval: Use FAISS to quickly retrieve an initial set of candidates based on the last step of the query sequence.
Iterative Refinement: Use fastdtw to iteratively refine the candidate pool. Starting with the last two steps of the query sequence, progressively consider longer suffixes (the last i steps).
At each step, compute the DTW distance between the current query suffix and the corresponding suffix of each candidate.
Sort the candidates by DTW distance and discard the worst-scoring 10% (those with the largest distances).
Continue until the whole query sequence has been used or only one candidate remains.
Final Selection: Compute the full-length DTW distance for the remaining candidates and select the one with the smallest distance.
'''


# Fast approach: FAISS candidate retrieval followed by iterative FastDTW refinement
import numpy as np
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean
import faiss
import os

def create_index(vectors):
    """
    Creates an exhaustive L2 FAISS index from a given set of vectors.
    Args:
    vectors (numpy.ndarray): A numpy array of shape (nb, d) where 'nb' is the number of base vectors and 'd' is the dimension.
        Any real dtype is accepted; the data is converted to C-contiguous float32 before being added.
    Returns:
    faiss.IndexFlatL2: The created FAISS index containing all 'nb' vectors.
    """
    # FAISS stores and searches float32 only. Some faiss releases reject float64
    # input while newer ones convert it silently, so convert explicitly to make
    # behaviour independent of the installed version. This is a no-op (no copy)
    # when the input is already contiguous float32.
    vectors = np.ascontiguousarray(vectors, dtype=np.float32)
    d = vectors.shape[1]  # dimension of the vectors
    index = faiss.IndexFlatL2(d)  # L2 distance index
    index.add(vectors)
    return index

def faiss_search(index, query, k=1000):
    """
    Search for k nearest neighbors in the FAISS index.
    Args:
    index (faiss.IndexFlatL2): The FAISS index.
    query (numpy.ndarray): Query vectors of shape (nq, d), or a single query vector of shape (d,).
        Converted to C-contiguous float32 before searching.
    k (int): Number of nearest neighbors to return per query.
    Returns:
    np.ndarray: Indices of the nearest neighbors, shape (nq, k). If k exceeds the number of
        vectors in the index, FAISS pads the missing entries with -1; callers must filter those
        out before using the ids as array indices (a -1 would silently select the last element).
    np.ndarray: Distances to the nearest neighbors, shape (nq, k) (squared L2 for IndexFlatL2).
    """
    # FAISS expects a 2-D batch of float32 queries; accept a lone 1-D vector too.
    query = np.ascontiguousarray(np.atleast_2d(query), dtype=np.float32)
    D, I = index.search(query, k)
    # Keep the original (indices, distances) return order, which is the reverse of FAISS's.
    return I, D

def fastdtw_distance(seq1, seq2):
    """
    Returns the FastDTW distance between two sequences.

    FastDTW approximates classic dynamic time warping; individual time steps
    are compared with the Euclidean distance.
    Args:
    seq1 (numpy.ndarray): The first sequence.
    seq2 (numpy.ndarray): The second sequence.
    Returns:
    float: The (approximate) DTW distance. The warping path that fastdtw also
        computes is discarded.
    """
    warp_cost, _warp_path = fastdtw(seq1, seq2, dist=euclidean)
    return warp_cost

context_length = 10
sample = 100000
feature_dim = 256 * 4 * 4  # size of one time step's feature vector

# Generate Sample Data
# Generate all sequences at once as one float32 array of shape
# (sample, context_length, feature_dim). Building a list of float64 arrays and
# then calling np.array(...) on it held two ~33 GB copies at the same time; a
# single float32 array is ~16 GB (lower `sample` for a quick run) and is already
# the dtype FAISS works in.
rng = np.random.default_rng()
sequences = rng.random((sample, context_length, feature_dim), dtype=np.float32)
# One row per time step for FAISS: row r is step r % context_length of sequence
# r // context_length. Reshaping a C-contiguous array returns a view, so no copy.
flattened_sequences = sequences.reshape(context_length * sample, -1)

# Create FAISS Index
index = create_index(flattened_sequences)

# Define a Query Sequence and Iteratively Refine the Candidate List
query_sequence = np.random.rand(context_length, 256 * 4 * 4)  # Example query sequence

# Step 1: Use the last step of the query sequence to find initial candidates
last_step_query = query_sequence[-1].reshape(1, -1)
initial_candidates, _ = faiss_search(index, last_step_query, k=1000)  # Start with 1000 candidates

# The index was built from `flattened_sequences`, i.e. one row per time step, so
# FAISS returns *row* ids in [0, context_length * sample). Map them back to
# sequence ids (row // context_length) before indexing `sequences`; using row ids
# directly runs past the end of `sequences`. FAISS pads missing results with -1,
# which would silently wrap to the last sequence, so drop those first.
# Several matching rows of one sequence collapse to a single id, so the pool may
# contain fewer than k distinct sequences.
row_ids = initial_candidates.ravel()
candidate_set = set((row_ids[row_ids >= 0] // context_length).tolist())

# Step 2: Iteratively refine the candidate list using ever longer query suffixes
for i in range(2, len(query_sequence) + 1):
    subseq_query = query_sequence[-i:]  # last i steps of the query
    # Compare the query suffix with the same-length suffix of every candidate.
    distances = [
        (candidate_idx, fastdtw_distance(subseq_query, sequences[candidate_idx][-i:]))
        for candidate_idx in candidate_set
    ]

    # Sort ascending by distance and drop the worst-scoring 10% (largest
    # distances); always keep at least one candidate.
    distances.sort(key=lambda x: x[1])
    cutoff_index = max(1, len(distances) - len(distances) // 10)
    candidate_set = {idx for idx, _ in distances[:cutoff_index]}

    if len(candidate_set) == 1:
        break

# Step 3: Score every surviving candidate against the full query sequence and
# keep the one with the smallest DTW distance.
full_scores = [
    (seq_id, fastdtw_distance(query_sequence, sequences[seq_id].reshape(context_length, 256 * 4 * 4)))
    for seq_id in candidate_set
]

highest_match = None
highest_score = float('inf')
for seq_id, score in full_scores:
    # Strict "<": on ties, the first candidate scored wins.
    if score < highest_score:
        highest_match, highest_score = seq_id, score

print(f"The most similar sequence is at index {highest_match} with a DTW distance of {highest_score}.")
