# Constructing Composite Distance Functions in FAISS
This notebook is an ongoing attempt to build a `faiss` nearest-neighbor search over a composite distance, here a product of L2 norms taken over contiguous subsets of the input vectors.
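
To make the target concrete, here is one way to write this composite distance. It assumes each $D$-dimensional vector splits into two contiguous blocks (the observation and action slices used later in this notebook) at a hypothetical split index $d_o$:

$$
d(x, y) = \lVert x_{1:d_o} - y_{1:d_o} \rVert_2 \cdot \lVert x_{d_o+1:D} - y_{d_o+1:D} \rVert_2
$$

Because `faiss.IndexFlatL2` reports squared L2 distances, multiplying the outputs of two such indices gives $d(x, y)^2$, which ranks neighbors in the same order as $d(x, y)$.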

**NOTE**: For a finished implementation of these composite norms applied to nearest-neighbor search, see the other notebook in this directory, which implements the metric using `sklearn`.

### Import Block
We use `faiss`, a library written in C++ with Python bindings, for high-performance similarity search, in this case nearest-neighbor search.

In [3]:
import faiss
import numpy as np

### Generate Dataset
We'll use random samples drawn from a standard multivariate normal (MVN) distribution as a toy dataset for this problem.

In [27]:
# Generate data
X = np.random.normal(loc=0, scale=1, size=(100000, 20)).astype(np.float32)

# Get dimensions
d1 = 2
d2 = X.shape[1] - d1

# Slice observations and actions
O = X[:, :d1]  # Observation slice
A = X[:, d1:]  # Action slice

# Create index for observations
index_1 = faiss.IndexFlatL2(d1)
index_1.add(O.astype(np.float32))

# Create index for actions
index_2 = faiss.IndexFlatL2(d2)
index_2.add(A.astype(np.float32))

### Compute Distances
This code block shows how to compute distances over subsets of the input. Each index was fit on one contiguous subset of the input space, so we query it with the matching subset of each query vector.

In [19]:
# Generate query vectors
X_search = np.random.normal(loc=0, scale=1, size=(256, 20)).astype(np.float32)

# Column slices are non-contiguous views, so copy each one to C-order
X_search_1 = X_search[:, :d1].copy(order="C")
X_search_2 = X_search[:, d1:].copy(order="C")

# Search each index over all stored vectors. The results get new names so
# that the dimensions d1 and d2 are not overwritten. IndexFlatL2 returns
# squared L2 distances, sorted in increasing order for each query.
dist_1, idx_1 = index_1.search(X_search_1, X.shape[0])
dist_2, idx_2 = index_2.search(X_search_2, X.shape[0])
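
Each search returns its results sorted by distance, so column `j` of the two distance matrices generally refers to different stored vectors. Before the per-subset distances can be multiplied, they have to be realigned by stored-vector id. The sketch below illustrates one way to do that. To stay runnable without FAISS, it mimics the sorted search output using NumPy; the helper names (`sorted_sq_dists`, `realign`), the array sizes, and the split point `d_first` are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d_first = 2                                            # hypothetical split point
stored = rng.normal(size=(50, 6)).astype(np.float32)   # stand-in for X
queries = rng.normal(size=(4, 6)).astype(np.float32)   # stand-in for X_search

def sorted_sq_dists(q, s):
    """Mimic an exhaustive flat L2 search: squared distances sorted in
    increasing order per query, plus the matching stored-vector ids."""
    sq = ((q[:, None, :] - s[None, :, :]) ** 2).sum(axis=2)
    ids = np.argsort(sq, axis=1)
    return np.take_along_axis(sq, ids, axis=1), ids

def realign(dist, ids):
    """Scatter sorted distances back so column j refers to stored vector j."""
    out = np.empty_like(dist)
    np.put_along_axis(out, ids, dist, axis=1)
    return out

dist_a, ids_a = sorted_sq_dists(queries[:, :d_first], stored[:, :d_first])
dist_b, ids_b = sorted_sq_dists(queries[:, d_first:], stored[:, d_first:])

# Product of squared per-subset distances, aligned by stored-vector id
composite = realign(dist_a, ids_a) * realign(dist_b, ids_b)

# Ids of the 5 nearest stored vectors under the composite distance
nearest = np.argsort(composite, axis=1)[:, :5]
```

The same `realign` step should apply to the `(distances, ids)` pair returned by a FAISS `search` call, provided the search covers every stored vector so that each id appears exactly once per row.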

### (Under Development) Composite FAISS Nearest Neighbor Class
This class, once finished, will compute a product of norms using flat indices over contiguous subsets of the input vectors.

In [None]:
class CompositeFaissKNeighbors:
    """Nearest-neighbor search with flat (exhaustive) FAISS indices that
    supports composite similarity search for composite kernels. The
    similarity is computed separately for states and for actions.

    Parameters:
        k (int): The default number of neighbors returned by a query.
    """
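    def __init__(self, k=50):
        self.index_obs = None  # Flat L2 index over observations, set in fit()
        self.index_action = None  # Flat L2 index over actions, set in fit()
        self.k = k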

    def fit(self, X, d_obs):
        """Fit one flat FAISS index on the observations and one on the actions.

        Parameters:
            X (np.array): Input array of shape (N, d_obs + d_action),
                where N is the number of samples, d_obs is the dimension of
                the observations, and d_action is the dimension of the actions.
                Note that the array must be of type np.float32.
            d_obs (int): The dimension of the observation space.
        """
        # Get dimensions of state and action spaces
        d_action = X.shape[1] - d_obs

        # Slice observations and actions
        O = X[:, :d_obs]  # Observation slice
        A = X[:, d_obs:]  # Action slice

        # Create index for observations
        self.index_obs = faiss.IndexFlatL2(d_obs)
        self.index_obs.add(O.astype(np.float32))

        # Create index for actions
        self.index_action = faiss.IndexFlatL2(d_action)
        self.index_action.add(A.astype(np.float32))

    def query(self, X, d_obs, k=None):
        """Query the fitted FAISS indices for nearest neighbors.

        Each query vector is split into its observation and action parts,
        and each part is searched against the matching index built in fit().

        Parameters:
            X (np.array): Array of shape (N, D), where N is the number of
                samples, and D is the dimension of the features.
            d_obs (int): The dimension of the observation space.
            k (int): If provided, the number of neighbors to compute. Defaults
                to None, in which case self.k is used as the number of neighbors.

        Returns:
            indices (np.array): Array of shape (N, K), where N is the number of
                samples, and K is the number of nearest neighbors to be computed.
                The ith row corresponds to the k-nearest neighbors of the ith
                sample.
        """
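        # Set number of neighbors
        if k is None:  # Use default number of neighbors
            k = self.k

        # Slice observations and actions into C-contiguous float32 arrays
        O = np.ascontiguousarray(X[:, :d_obs], dtype=np.float32)
        A = np.ascontiguousarray(X[:, d_obs:], dtype=np.float32)

        # TODO: evaluate compute_distance_subset() and search_and_reconstruct()
        # as alternatives to the exhaustive search used below.

        # Provisional approach: search both indices over all stored vectors.
        # IndexFlatL2 returns squared L2 distances, sorted per query.
        n_total = self.index_obs.ntotal
        dist_obs, idx_obs = self.index_obs.search(O, n_total)
        dist_action, idx_action = self.index_action.search(A, n_total)

        # Realign each distance matrix so that column j refers to stored
        # vector j, then multiply to get the composite (squared) distance
        aligned_obs = np.empty_like(dist_obs)
        aligned_action = np.empty_like(dist_action)
        np.put_along_axis(aligned_obs, idx_obs, dist_obs, axis=1)
        np.put_along_axis(aligned_action, idx_action, dist_action, axis=1)
        composite = aligned_obs * aligned_action

        # Return the ids of the k smallest composite distances per query
        indices = np.argsort(composite, axis=1)[:, :k]
        return indices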
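
While the class is under development, a brute-force reference can help check whatever `query` ends up returning. The sketch below is NumPy-only and is not part of the class. The function name `composite_knn_reference` and the toy array sizes are hypothetical. It computes the product of the two L2 distances directly and keeps the `k` smallest per query.

```python
import numpy as np

def composite_knn_reference(X_fit, X_query, d_obs, k):
    """Brute-force k nearest neighbors under the product of L2 distances
    over the observation block [:d_obs] and the action block [d_obs:]."""
    def l2(q, s):
        # Pairwise L2 distances between rows of q and rows of s
        diff = q[:, None, :] - s[None, :, :]
        return np.sqrt((diff ** 2).sum(axis=2))

    composite = (l2(X_query[:, :d_obs], X_fit[:, :d_obs])
                 * l2(X_query[:, d_obs:], X_fit[:, d_obs:]))

    # argpartition finds the k smallest entries per row without a full sort
    part = np.argpartition(composite, k - 1, axis=1)[:, :k]
    # Order those k candidates by increasing composite distance
    order = np.argsort(np.take_along_axis(composite, part, axis=1), axis=1)
    return np.take_along_axis(part, order, axis=1)

rng = np.random.default_rng(0)
X_fit = rng.normal(size=(200, 5)).astype(np.float32)
X_query = X_fit[:3].copy()  # query with stored points, so each is its own nearest neighbor
neighbors = composite_knn_reference(X_fit, X_query, d_obs=2, k=4)
```

The broadcasted difference allocates an array of shape (queries, stored vectors, dimensions), so this reference is only practical for small inputs.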
