In [None]:
import numpy as np 
from sklearn import datasets
from scipy.cluster.hierarchy import ward, fcluster
from scipy.stats import multivariate_normal
from scipy.spatial import ConvexHull
import open3d as o3d

In [3]:
random_state = 71  # fixed seed so the synthetic data set is reproducible

In [4]:
# create a simple dataset: n blobs in 3-D whose centres lie inside the box [-20, 20]
n_samples = 500
n = 4

X, y = datasets.make_blobs(
    n_samples=n_samples,
    n_features=3,
    centers=n,
    center_box=(-20.0, 20.0),
    random_state=random_state,
)

# optional anisotropic variant (disabled):
#transformation = [[1.5, -1, 1], [-1.5, 2, 1], [-1.5, 2, 1]]
#X = np.dot(X, transformation)

In [None]:
class KernelParameters:
    '''
    A class to store the kernel parameters of all kernels in the fault network

    mean: np.array
        Array of shape (n_kernels, n_dimensions) that contains the mean for every kernel.

    cov: np.array
        Array of shape (n_kernels, n_dimensions, n_dimensions) that contains the covariance matrices for every kernel.

    weight: np.array
        Array of shape (n_kernels,) that contains the weight of every kernel.

    bbox: np.array
        An array of shape (n_kernels, 2**n_dimensions, n_dimensions) that contains the corners of the bounding
        box of every kernel, i.e. (n_kernels, 8, 3) in 3-D.

    is_bkg: np.array
        A boolean array of shape (n_kernels,) that indicates which kernel is a background kernel

    Floating point fields are created as np.longdouble. On platforms that provide np.float128 this is the
    same type; np.float128 itself does not exist everywhere (e.g. on Windows).
    '''

    def __init__(self, n_dim: int = 3):

        self.n_kernels = 0
        self.n_dim = n_dim

        self.mean = np.zeros((0, n_dim), dtype=np.longdouble)
        self.cov = np.zeros((0, n_dim, n_dim), dtype=np.longdouble)
        self.weight = np.zeros((0,), dtype=np.longdouble)
        self.bbox = np.zeros((0, 2**n_dim, n_dim), dtype=np.longdouble)
        self.is_bkg = np.zeros((0,), dtype=bool)

    def get_dim(self):
        '''Return the dimensionality of the kernels.'''
        return self.n_dim

    def get_n_kernels(self):
        '''Return the number of stored kernels.'''
        return self.n_kernels

    def add_kernels(self, mean, cov, weight, bbox, is_bkg):
        '''
        Append new kernels to the kernel configuration.
        Every argument must carry a leading kernel axis of the same length.
        '''

        self.mean = np.concatenate([self.mean, mean], axis=0)
        self.cov = np.concatenate([self.cov, cov], axis=0)
        self.weight = np.concatenate([self.weight, weight], axis=0)
        self.bbox = np.concatenate([self.bbox, bbox], axis=0)
        self.is_bkg = np.concatenate([self.is_bkg, is_bkg], axis=0)

        # derive the count from the stored data so it can never drift
        self.n_kernels = len(self.weight)

    def concatenate_parameters(self, kernelp):
        '''Append all kernels of another KernelParameters object.'''

        self.add_kernels(kernelp.mean, kernelp.cov, kernelp.weight, kernelp.bbox, kernelp.is_bkg)

    def modify_kernels(self, kernel_idx, mean, cov, weight, bbox, is_bkg):
        '''
        Modify a single or multiple existing kernels in the kernel configuration.
        Make sure that the type of kernel_idx and the shape of the other arguments match.
        '''

        self.mean[kernel_idx] = mean
        self.cov[kernel_idx] = cov
        self.weight[kernel_idx] = weight
        self.bbox[kernel_idx] = bbox
        self.is_bkg[kernel_idx] = is_bkg

    def modify_weight(self, weight):
        '''
        Replace the weights of all kernels in the kernel configuration.

        Raises
        ------
        ValueError
            If the new weights do not have the same shape as the current ones.
        '''

        weight = np.asarray(weight)

        if self.weight.shape != weight.shape:
            raise ValueError('Input weight matrix must have same shape as current weight matrix')

        self.weight = weight

    def get(self, field: str):
        '''
        Get a specific field from the kernel parameters.
        'm' -> mean, 'c' -> cov, 'w' -> weight, 'b' -> bbox, 'ib' -> is_bkg.
        '''

        if field == 'm':
            return self.mean
        elif field == 'c':
            return self.cov
        elif field == 'w':
            return self.weight
        elif field == 'b':
            return self.bbox
        elif field == 'ib':
            return self.is_bkg
        else:
            raise ValueError(f'Unknown field {field}')

    def delete_kernels(self, kernel_idx):
        '''
        Remove a single or multiple kernels from the kernel configuration.

        Parameters
        -----------
        kernel_idx: int | array_like
            Either an integer / sequence of integer indices of the kernels to delete (an empty
            sequence deletes nothing, duplicates are ignored), or a boolean mask of length
            n_kernels that indicates which kernels to KEEP.
        '''

        idx = np.asarray(kernel_idx)

        if idx.dtype == bool:
            # boolean mask (Python bools and numpy bools alike): True marks kernels to keep
            if idx.shape != (self.n_kernels,):
                raise ValueError(f'Boolean mask must have shape ({self.n_kernels},), was {idx.shape}')
            keep = idx
        else:
            # integer index / indices of kernels to delete; np.asarray([]) is float, hence astype
            keep = np.ones(self.n_kernels, dtype=bool)
            keep[idx.astype(int)] = False

        self.mean = self.mean[keep]
        self.cov = self.cov[keep]
        self.weight = self.weight[keep]
        self.bbox = self.bbox[keep]
        self.is_bkg = self.is_bkg[keep]

        self.n_kernels = int(np.count_nonzero(keep))

    def get_kernels(self, kernel_idx = None):
        '''
        Extract the kernel parameters of a single or multiple kernels.
        If kernel_idx = None, return all kernels.
        If kernel_idx is an integer, the outer dimension of the return values is flattened.

        Returns
        -------
        mean, cov, weight, bbox, is_bkg in that order.
        '''

        if kernel_idx is None:
            kernel_idx = range(self.n_kernels)

        return self.mean[kernel_idx], self.cov[kernel_idx], self.weight[kernel_idx], self.bbox[kernel_idx], self.is_bkg[kernel_idx]

    def is_background(self, kernel_idx):
        '''Return whether the kernel(s) at kernel_idx are background kernels.'''

        return self.is_bkg[kernel_idx]


In [4]:
def get_capacity_clusters(X: np.array, 
                          min_sz_cluster: int = 4, 
                          min_n_merges: int = 4
                          )->np.array:
    '''
    Get the cluster assignment with the largest number of valid clusters based on ward linkage.
    Iteratively cuts the tree from the bottom until the number of valid clusters does not increase anymore.

    Parameters
    -----------
    X: np.array
        The data to compute the cluster assignment for as an array of size (n_samples x n_dimensions)

    min_sz_cluster: int
        The threshold on the cluster size for a cluster to be considered valid, default = 4

    min_n_merges: int
        The number of cluster merging steps that can be skipped by the algorithm, default = 4
        (raised to min_sz_cluster if smaller).

    Returns
    --------
    capacity_labels: np.array
        Integer cluster labels as an array of length n_samples. All zeros if X has too few
        points to build or cut a linkage tree.
    '''

    min_n_merges = max(min_sz_cluster, min_n_merges)

    n_points = len(X)
    capacity_labels = np.zeros(n_points, dtype=int)

    # ward() needs at least two observations
    if n_points < 2:
        return capacity_labels

    link_tree = ward(X)
    capacity = 0

    for n_merges in range(min_n_merges, n_points):

        # cut at the height of merge number n_merges; labels are shifted to start at 0
        labels = fcluster(link_tree, link_tree[n_merges-1, 2], "distance") - 1
        _, counts = np.unique(labels, return_counts=True)

        n_clusters = np.count_nonzero(counts >= min_sz_cluster)

        # ">=" keeps the coarsest cut among those with equal capacity
        if n_clusters >= capacity:
            capacity = n_clusters
            capacity_labels = labels

        # capacity dropped and every remaining cluster is already valid: stop
        elif n_clusters == len(counts):
            break

    return capacity_labels


In [22]:
def fit_gaussian_kernels(X: np.array,
                         cluster_labels: np.array,
                         min_sz_cluster: int = 4
                         )->tuple:
    '''
    Fit a Gaussian kernel to every valid cluster. Fits mean, covariance and weight for every kernel.
    If there are points that are not in any valid cluster, fit a uniform background kernel
    (only implemented for 3-dimensional data, because it relies on get_minimum_bbox).

    Parameters
    -----------
    X: np.array
        An array of observations. Must be of shape (n_samples, n_dimensions).

    cluster_labels: np.array
        The cluster assignment of every observation. Must be of length n_samples

    min_sz_cluster: int
        The threshold on the cluster size for a cluster to be considered valid, default = 4

    Returns
    -------
    kernels: KernelParameters
        The kernel configuration after fitting Gaussian kernels to valid clusters.
        Bounding boxes of the Gaussian kernels are left as NaN; only the background
        kernel gets a box here.

    Raises
    ------
    ValueError
        If X and cluster_labels have different lengths.
    '''

    if len(X) != len(cluster_labels):
        raise ValueError(f'Number of datapoints {len(X)} does not match number of labels {len(cluster_labels)}')

    cluster_labels = np.asarray(cluster_labels)

    # determine dataset parameters
    n_points, n_dim = X.shape
    uq_clusters, cluster_szs = np.unique(cluster_labels, return_counts=True)
    valid_clusters = uq_clusters[cluster_szs >= min_sz_cluster]
    n_clusters = len(valid_clusters)

    # background kernel only when some points are outside valid clusters and the data is 3-D
    fit_background = n_clusters < len(uq_clusters) and n_dim == 3
    n_kernels = n_clusters + int(fit_background)

    mean = np.zeros((n_kernels, n_dim))
    covar = np.repeat([np.eye(n_dim)], n_kernels, axis=0)
    weight = np.zeros(n_kernels)
    bbox = np.full((n_kernels, 2**n_dim, n_dim), np.nan)
    is_bkg = np.full(n_kernels, False)

    for i, label in enumerate(valid_clusters):

        X_curr = X[cluster_labels == label]

        mean[i, :] = np.mean(X_curr, axis=0)
        covar[i, :, :] = np.cov(X_curr, rowvar=False)
        weight[i] = len(X_curr) / n_points

    if fit_background:
        # the background kernel is uniform over the minimum box around all unclustered points
        X_bkg = X[~np.isin(cluster_labels, valid_clusters)]
        bbox[-1, :, :], center = get_minimum_bbox(X_bkg)
        is_bkg[-1] = True

        mean[-1, :] = center

        # np.cov of the 8 box corners gives 2*L**2/7 along an edge of length L, so the factor
        # 3.5 makes the eigenvalues equal to the squared edge lengths; get_kernel_prob then uses
        # 1/prod(sqrt(eigvals)) = 1/volume as the uniform density.
        #FIXME: sqrt(12)*stddev in paper but in the implementation it's sqrt(12)*variance?
        covar[-1, :, :] = np.cov(bbox[-1], rowvar=False) * 3.5

        weight[-1] = len(X_bkg) / n_points

    kernels = KernelParameters(n_dim)
    kernels.add_kernels(mean, covar, weight, bbox, is_bkg)

    return kernels

        

def get_minimum_bbox(X: np.array)->tuple:

    '''
    Calculate the minimum volume oriented bounding box for the points in X

    Parameters
    -----------
    X: np.array
        An array containing a point cloud of observations. Has to be of shape (n_samples, 3).
        Any float dtype is accepted (e.g. extended-precision corner arrays stored in
        KernelParameters); the points are converted to float64 for open3d.

    Returns
    --------
    corners: np.array
        The 8 corner points of the bounding box as an array of shape (8,3)

    center: np.array
        The center point of the bounding box as an array of shape (3,)

    Raises
    ------
    ValueError
        If the data is not 3-dimensional.

    NOTE(review): degenerate inputs (too few or coplanar points) may make open3d raise.
    '''

    # open3d's Vector3dVector converts a float64 (n, 3) array
    X = np.ascontiguousarray(X, dtype=np.float64)

    if X.shape[1] != 3:
        raise ValueError(f'Data has to be 3-dimensional, was {X.shape[1]}-dimensional')

    # create a point cloud object from the data
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(X)

    # get the corners and center of the minimum bounding box
    bbox = cloud.get_minimal_oriented_bounding_box()

    corners = np.asarray(bbox.get_box_points())
    center = bbox.get_center()

    return corners, center


def get_gaussian_bbox(mean: np.array,
                      cov: np.array,
                      n_var: float)->tuple:
    '''
    Calculate an oriented bounding box of a 3D-Gaussian aligned with its principal axes.

    Along principal axis i (eigenvalue lambda_i of cov) the box edge has length
    sqrt(n_var * lambda_i). For n_var = 12 this is the edge length of a uniform
    distribution with variance lambda_i.

    Parameters
    ----------

    mean: np.array
        Mean vector of the Gaussian. Must be of shape (3,).

    cov: np.array
        Covariance matrix of the Gaussian. Must be of shape (3,3).

    n_var: float
        Scale factor applied to the covariance before taking the square root of its eigenvalues.

    Returns
    --------

    corners: np.array
        The 8 corner points of the bounding box as an array of shape (8,3)

    mean: np.array
        The center point of the bounding box as an array of shape (3,). Is equivalent to mean vector of the Gaussian.

    Raises
    ------
    ValueError
        If mean or cov are not 3-dimensional.
    '''

    # numpy.linalg does not support extended precision, so work in float64
    mean = np.asarray(mean, dtype=np.float64)
    cov = np.asarray(cov, dtype=np.float64)

    if mean.shape != (3,) or cov.shape != (3,3):
        raise ValueError('Gaussian must be 3-dimensional')

    evals, evecs = np.linalg.eigh(n_var*cov)

    # round-off can make eigenvalues of a PSD matrix slightly negative
    evals = np.clip(evals, 0.0, None)

    # half of the edge length along every principal axis
    half_sides = 0.5*np.sqrt(evals)

    # unrotated unit cuboid centered around 0 (one corner per column)
    corners = np.array([[-1, 1, 1, -1, -1, 1, 1, -1],
                        [1, 1, 1, 1, -1, -1, -1, -1],
                        [-1, -1, 1, 1, 1, 1, -1, -1]], dtype=np.float64)

    corners = corners * half_sides[:, np.newaxis]

    # rotate into the eigenbasis, shift to the mean and return one corner per row
    corners = (evecs @ corners).T + mean

    return corners, mean

In [None]:

def inhull(X: np.array,
           hull_pts: np.array,
           eps: float = np.finfo(float).eps
          )->np.array:
    '''
    Check for all points in X whether they lie in the convex hull defined by hull_pts.
    Adapted from https://stackoverflow.com/questions/31404658/check-if-points-lies-inside-a-convex-hull

    Parameters
    -----------
    X: np.array
        An array of observations. Must be of shape (n_samples, n_dimensions).

    hull_pts: np.array
        An array of points from which the convex hull is determined. Must be of shape (n_points, n_dimensions).

    eps: float
        The tolerance to be used when checking whether a given point is inside the hull.
        Choose > 0 to avoid numerical issues, default = np.finfo(float).eps

    Returns
    --------
    in_hull: np.array
        A boolean array of shape (n_samples,) indicating which points in X are inside the hull

    Raises
    ------
    ValueError
        If X and hull_pts have a different number of columns.
    '''

    if X.shape[1] != hull_pts.shape[1]:
        raise ValueError(f'Hull points and test points must have the same dimensions.')

    hull = ConvexHull(hull_pts)

    # A is shape (f, d) and b is shape (f,): one row per facet, normals pointing outwards.
    A, b = hull.equations[:, :-1], hull.equations[:, -1]

    # The hull is all points x with A x + b <= 0; evaluate every point/facet pair at once.
    # Always a bool array, even when X is empty.
    in_hull = np.all(X @ A.T + b < eps, axis=1)

    return in_hull

In [None]:
def assign_to_kernel(X: np.array,
                     kernels: KernelParameters,
                     min_sz_cluster: int = 4
                     ):
    '''
    Perform a single step of (hard-assignment) expectation maximization: assign each data point
    to its most probable kernel, refit the kernels and drop kernels that lost their support.

    Non-background kernels with at least min_sz_cluster points get a new mean, covariance,
    weight and bounding box. Background kernels with at least one point only get a new weight:
    their mean/covariance encode the uniform box density (see fit_gaussian_kernels), so they
    keep their geometry. All other kernels are deleted.

    Parameters
    ----------
    X: np.ndarray
        An array of observations. Must be of shape (n_samples, n_dimensions).

    kernels: KernelParameters
        The current kernel configuration (modified in place)

    min_sz_cluster: int
        The threshold on the cluster size for a cluster to be considered valid, default = 4

    Returns
    --------
    kernels: KernelParameters
        The updated kernel configuration

    k_assign: np.ndarray
        An array of shape (n_samples,) which indicates the kernel assignment of every point in X

    kernel_prob: np.ndarray
        An array of shape (n_samples, n_kernels) that contains the probability of each datapoint under each kernel
    '''

    n_points = len(X)

    kernel_prob = get_kernel_prob(X, kernels)
    k_assign = np.argmax(kernel_prob, axis=1)

    del_ker = []

    for i in range(kernels.get_n_kernels()):

        msk = k_assign == i
        cluster_sz = np.count_nonzero(msk)
        is_bkg = kernels.is_background(i)

        if cluster_sz >= min_sz_cluster or (is_bkg and cluster_sz > 0):

            # start from the current parameters (kept as-is for background kernels)
            mean, cov, _, bbox, _ = kernels.get_kernels(i)
            weight = cluster_sz / n_points

            if not is_bkg:
                X_curr = X[msk, :]

                cov = np.cov(X_curr, rowvar=False)
                mean = np.mean(X_curr, axis=0)

                # bbox defined by the points in the cluster
                bbox_pts, _ = get_minimum_bbox(X_curr)

                # bbox defined by the Gaussian kernel
                bbox_gauss, _ = get_gaussian_bbox(mean, cov, 12)

                # bbox is minimum bbox of the union of all corner points
                bbox, _ = get_minimum_bbox(np.concatenate((bbox_pts, bbox_gauss), axis=0))

            kernels.modify_kernels(i, mean, cov, weight, bbox, is_bkg)

        else:
            del_ker.append(i)

    # delete kernels with too few points (background kernels only if they have none)
    if del_ker:
        kernels.delete_kernels(del_ker)

    # update probabilities and kernel assignment with the refitted kernels
    kernel_prob = get_kernel_prob(X, kernels)
    k_assign = np.argmax(kernel_prob, axis=1)

    return kernels, k_assign, kernel_prob



def get_kernel_prob(X: np.array,
                     kernels: 'KernelParameters'
                     ):
    '''
    Determine the weighted probability density of each data point under each kernel.

    Gaussian kernels contribute weight * N(x; mean, cov). Background kernels contribute the
    uniform density weight / prod(sqrt(eigvals(cov))) to points inside their bounding box
    and 0 elsewhere.

    Parameters
    ----------
    X: np.ndarray
        An array of observations. Must be of shape (n_samples, n_dimensions).

    kernels: KernelParameters
        The current kernel configuration

    Returns
    --------
    kernel_prob: np.array
        An array of shape (n_samples, n_kernels) in extended precision (np.longdouble)
    '''

    n_kernels = kernels.get_n_kernels()
    kernel_prob = np.zeros((len(X), n_kernels), dtype=np.longdouble)

    for i in range(n_kernels):

        mean, cov, weight, bbox, is_bkg = kernels.get_kernels(i)

        # scipy.stats / numpy.linalg do not support extended precision
        mean = np.asarray(mean, dtype=np.float64)
        cov = np.asarray(cov, dtype=np.float64)

        if not is_bkg:
            kernel_prob[:, i] = weight * multivariate_normal.pdf(X, mean=mean, cov=cov)

        else:
            bbox = np.asarray(bbox, dtype=np.float64)

            # tolerance relative to the size of the box coordinates
            eps = 1.e-10*np.mean(np.abs(bbox))
            bkg_pts = inhull(X, bbox, eps)

            # uniform density inside the box
            kernel_prob[bkg_pts, i] = weight / np.prod(np.sqrt(np.linalg.eigvalsh(cov)), axis=0)

    return kernel_prob


def get_bic(kernel_prob: np.array,
            n_kernels: int = None):
    '''
    Calculate the Bayesian Information Criterion of the dataset from the probability of each datapoint under each kernel

    The value is -log L + 0.5 * (10 * n_kernels - 1) * log(n_points), where L is the product over data
    points of the summed (mixture) density. The 10 parameters per kernel are presumably 3 mean + 6
    covariance + 1 weight for 3-D kernels, minus one because the weights sum to one.

    Parameters
    ----------
    kernel_prob: np.array
        An array of shape (n_samples, n_kernels) that contains the probability of every sample under every kernel.
        May also be an already summed (n_samples, 1) array.

    n_kernels: int
        Optional number of kernels to use. Can be specified if kernel_prob is already a cumulative probability, default = None

    Returns
    --------
    bic: float
        The BIC value of the current kernel configuration
    '''

    kernel_prob = np.asarray(kernel_prob)
    n_points = kernel_prob.shape[0]

    if n_kernels is None:
        n_kernels = kernel_prob.shape[1]

    # mixture density per data point: sum over the kernel axis
    total_prob = np.sum(kernel_prob, axis=1)

    # avoid log(0) for points that no kernel explains
    eps = np.finfo(np.longdouble).eps
    total_prob = np.maximum(total_prob, eps)

    bic = np.sum(-np.log(total_prob)) + 0.5*(10*n_kernels - 1)*np.log(n_points)

    return bic

In [None]:
def have_overlap(bbox1: np.array,
                 bbox2: np.array):
    '''
    Check whether two bounding boxes overlap.

    The test reports an overlap when at least one corner of either box lies inside the other box.
    Limitation: boxes that cross each other without containing any of each other's corners
    (e.g. two thin boxes forming a plus sign) are reported as not overlapping.

    Parameters
    ----------
    bbox1: np.array
        The corner points of the first bounding box as an array of shape (n_points, n_dimensions)

    bbox2: np.array
        The corner points of the second bounding box as an array of shape (n_points, n_dimensions)

    Returns
    --------
    overlap: bool
        Indicates whether the two bounding boxes overlap
    '''

    # inhull returns one flag per corner, so reduce with any() before combining
    overlap = bool(np.any(inhull(bbox1, bbox2))) or bool(np.any(inhull(bbox2, bbox1)))

    return overlap



def merge_single_pair(kernel_pair: KernelParameters,
                      keep_wgt: bool = True):
    '''
    Combine the two kernels of kernel_pair into one merged kernel.

    The merged bounding box is the minimum oriented box around the corners of both input boxes.
    If either kernel is a background kernel, the merged kernel is a background kernel centred in that
    box whose covariance is derived from the box corners. Otherwise mean and covariance are obtained
    by weighted moment matching of the two Gaussians.

    Parameters
    ----------
    kernel_pair: KernelParameters
        Contains exactly the two kernels that should be merged

    keep_wgt: bool
        If True the merged weight is the sum of both weights, otherwise it is set to 1, default = True

    Returns
    --------
    mean: np.array
        The mean vector of the merged kernel, has shape (1, n_dim)

    cov: np.array
        The covariance matrix of the merged kernel, has shape (1, n_dim, n_dim)

    weight: np.array
        The weight of the merged kernel, has shape (1,)

    bbox: np.array
        The bounding box of the merged kernel, has shape(1, 2**n_dim, n_dim)

    is_bkg: np.array
        Indicates whether the merged kernel is a background kernel, has shape (1,)

    Raises
    ------
    ValueError
        If kernel_pair does not contain exactly two kernels.
    '''
    if kernel_pair.get_n_kernels() != 2:
        raise ValueError('Kernel configuration is not a pair')

    mean_a, cov_a, wgt_a, box_a, bkg_a = kernel_pair.get_kernels(0)
    mean_b, cov_b, wgt_b, box_b, bkg_b = kernel_pair.get_kernels(1)

    # box enclosing the corners of both kernels
    union_corners = np.concatenate((box_a, box_b), axis=0)
    merged_box, box_center = get_minimum_bbox(union_corners)

    total_wgt = wgt_a + wgt_b
    merged_bkg = bkg_a or bkg_b

    if not merged_bkg:
        # weighted moment matching of the two Gaussians
        merged_mean = 1/total_wgt * (wgt_a * mean_a + wgt_b * mean_b)
        dev_a = mean_a - merged_mean
        dev_b = mean_b - merged_mean
        merged_cov = (wgt_a/total_wgt)*(cov_a + np.outer(dev_a, dev_a)) + (wgt_b/total_wgt)*(cov_b + np.outer(dev_b, dev_b))
    else:
        # uniform kernel: factor 3.5 turns the corner covariance into squared edge lengths
        merged_mean = box_center
        merged_cov = np.cov(merged_box, rowvar=False)*3.5

    out_wgt = total_wgt if keep_wgt else 1

    return (np.array([merged_mean]), np.array([merged_cov]), np.array([out_wgt]),
            np.array([merged_box]), np.array([merged_bkg]))



def get_disjoint_pairs(rows: np.array,
                       cols: np.array,
                       score: np.array):
    '''
    Get the disjoint pairs from a set of row-column index pairs.
    Pairs are visited by descending score (ties in input order) and a pair is kept only if neither
    of its indices already appears (as row or column) in a previously kept pair.

    Parameters
    -----------
    rows: np.array
        An array of row indices

    cols: np.array
        An array of column indices in the same order as rows

    score: np.array
        An array of scores that determine the precedence of a pair in the same order as rows.
        Higher score is better and pairs with non-positive (or NaN) scores are cut.

    Returns
    -------
    rows, columns, score: np.array
        The row and column indices and scores of the disjoint pairs sorted by descending score

    idx_sort: np.array
        Integer positions of the kept pairs in the INPUT arrays, in the same order as the returned
        rows/columns, so that other arrays aligned with the input can be indexed with it.
    '''

    rows = np.asarray(rows)
    cols = np.asarray(cols)
    score = np.asarray(score)

    # only pairs with improvement; NaN compares False and is dropped too
    candidates = np.flatnonzero(score > 0)

    # descending score, stable so ties keep their input order
    order = candidates[np.argsort(-score[candidates], kind='stable')]

    used = set()
    kept = []

    # greedily keep the best pair whose indices are still unused
    for idx in order:
        r, c = rows[idx], cols[idx]
        if r in used or c in used:
            continue
        used.update((r, c))
        kept.append(idx)

    idx_sort = np.asarray(kept, dtype=int)

    return rows[idx_sort], cols[idx_sort], score[idx_sort], idx_sort


def merge_kernel_assignment(kernel_assign: np.ndarray,
                            kernel1: np.ndarray, 
                            kernel2: np.ndarray):
    '''
    Map the old kernel labels to the new kernel labels after merging pairs of kernels.
    Kernel labels coincide with position of the kernel in the KernelParameters object.

    The new layout is: all kernels that are not part of a merged pair keep their relative order
    and come first; merged pair i (kernel1[i], kernel2[i]) becomes the kernel right after them,
    at position n_kept + i. Kernels are taken to be numbered 0..m, where m is the largest index
    appearing in any of the inputs.

    Parameters
    -----------
    kernel_assign: np.ndarray
        The old kernel assignment of every data point. Has shape (n_samples,)

    kernel1, kernel2: np.ndarray
        Contain the labels of the first and second kernel of every merged pair in the same order

    Returns
    --------
    kernel_assign: np.ndarray
        The new kernel assignment of every data point

    Raises
    ------
    ValueError
        If kernel1 and kernel2 have different lengths.
    '''

    kernel_assign = np.asarray(kernel_assign, dtype=int)
    kernel1 = np.asarray(kernel1, dtype=int)
    kernel2 = np.asarray(kernel2, dtype=int)

    if len(kernel1) != len(kernel2):
        raise ValueError('kernel1 and kernel2 must have the same length')

    # labels need not all be present in kernel_assign, so size the table by the largest index
    n_labels = 1 + max((arr.max() for arr in (kernel_assign, kernel1, kernel2) if arr.size), default=-1)

    is_kept = np.ones(n_labels, dtype=bool)
    is_kept[kernel1] = False
    is_kept[kernel2] = False
    n_kept = np.count_nonzero(is_kept)

    old2new = np.empty(n_labels, dtype=int)

    # kept kernels are renumbered consecutively
    old2new[is_kept] = np.arange(n_kept)

    # both members of pair i map to the merged kernel appended after the kept ones
    pair_labels = n_kept + np.arange(len(kernel1))
    old2new[kernel1] = pair_labels
    old2new[kernel2] = pair_labels

    return old2new[kernel_assign]


def _merge_candidates(kernels: KernelParameters):
    '''Return the index pairs (r, c), r < c, of non-background kernels whose bounding boxes overlap.'''

    rows, cols = np.triu_indices(kernels.get_n_kernels(), k=1)
    bboxes = kernels.get('b')

    is_candidate = np.zeros(len(rows), dtype=bool)
    for idx, (r, c) in enumerate(zip(rows, cols)):
        # background kernels are never merged
        if kernels.is_background(r) or kernels.is_background(c):
            continue
        is_candidate[idx] = have_overlap(bboxes[r], bboxes[c])

    return rows[is_candidate], cols[is_candidate]


def _pair_kernels(kernels: KernelParameters, r, c):
    '''Copy kernels r and c into a new two-kernel KernelParameters object.'''

    pair = KernelParameters(kernels.get_dim())
    pair.add_kernels(*kernels.get_kernels([r, c]))
    return pair


def _merged_kernel(kernel_pair: KernelParameters, keep_wgt: bool):
    '''Return a one-kernel KernelParameters object holding the merge of kernel_pair.'''

    merged = KernelParameters(kernel_pair.get_dim())
    merged.add_kernels(*merge_single_pair(kernel_pair, keep_wgt=keep_wgt))
    return merged


def merge_clusters(X: np.ndarray,
                   kernels: KernelParameters,
                   kernel_assign: np.ndarray,
                   init_prob: np.ndarray,
                   gain_mode: str = 'global'):
    '''
    Merge pairs of overlapping Gaussian kernels iteratively as long as merging improves the BIC.

    Every iteration scores all pairs of non-background kernels whose bounding boxes overlap.
    The disjoint pairs with a positive score are then merged at the same time.

    gain_mode = 'local': score = BIC of the points assigned to the pair under the two separate
        kernels (weights renormalised to those points) minus BIC under the merged kernel.
        Once no local merge improves, the algorithm continues in 'global' mode.
    gain_mode = 'global': score = BIC of the whole data set before minus after merging the pair.

    Parameters
    ----------
    X: np.ndarray
        An array of observations of shape (n_samples, n_dimensions).

    kernels: KernelParameters
        The current kernel configuration (modified in place).

    kernel_assign: np.ndarray
        Kernel label of every data point, shape (n_samples,).

    init_prob: np.ndarray
        Probability of every data point under every kernel, shape (n_samples, n_kernels).

    gain_mode: str
        'local' or 'global', default = 'global'

    Returns
    --------
    kernels: KernelParameters
        The kernel configuration after merging.

    kernel_assign: np.ndarray
        The kernel labels of all points, mapped to the new kernel positions.

    Raises
    ------
    ValueError
        If gain_mode is unknown.
    '''

    modes = ['local', 'global']

    if gain_mode not in modes:
        raise ValueError(f'Unknown gain mode {gain_mode}. Must be one of {modes}.')

    kernel_assign = np.asarray(kernel_assign)

    bic_init = get_bic(init_prob)

    # mixture density of every data point, shape (n_samples, 1)
    tot_prob_init = np.sum(init_prob, axis=1, keepdims=True)

    while True:

        rows, cols = _merge_candidates(kernels)

        if len(rows) == 0:
            break

        n_kernels = kernels.get_n_kernels()
        merge_score = []
        p_merged = []       # global mode: merged-kernel density per candidate pair
        p_separate = []     # global mode: summed density of the two kernels per candidate pair
        new_kernels = KernelParameters(kernels.get_dim())

        for r, c in zip(rows, cols):

            old_kerns = _pair_kernels(kernels, r, c)

            if gain_mode == 'local':

                # only consider contributions from points assigned to the kernel pair
                in_pair = np.logical_or(kernel_assign == r, kernel_assign == c)
                X_local = X[in_pair]

                if len(X_local) == 0:
                    merge_score.append(np.nan)
                    continue

                # bic under the merged kernel (weight 1 on the local points)
                new_kern = _merged_kernel(old_kerns, keep_wgt=False)
                bic_merged = get_bic(get_kernel_prob(X_local, new_kern))

                # bic under the two separate kernels with weights renormalised to the local points
                local_wgt = np.array([np.count_nonzero(kernel_assign == r),
                                      np.count_nonzero(kernel_assign == c)], dtype=np.longdouble) / len(X_local)
                old_kerns.modify_weight(local_wgt)
                bic_separate = get_bic(get_kernel_prob(X_local, old_kerns))

                merge_score.append(bic_separate - bic_merged)

            else:
                new_kern = _merged_kernel(old_kerns, keep_wgt=True)

                tot_prob_merged = np.sum(get_kernel_prob(X, new_kern), axis=1, keepdims=True)
                tot_prob_separate = np.sum(get_kernel_prob(X, old_kerns), axis=1, keepdims=True)
                p_merged.append(tot_prob_merged)
                p_separate.append(tot_prob_separate)

                # swap the pair's contribution for the merged kernel's contribution
                sum_tot_prob = tot_prob_init - tot_prob_separate + tot_prob_merged
                bic_merged = get_bic(sum_tot_prob, n_kernels=n_kernels-1)

                merge_score.append(bic_init - bic_merged)

                # position in new_kernels matches position in rows/cols
                new_kernels.concatenate_parameters(new_kern)

        merge_score = np.asarray(merge_score, dtype=float)

        if not np.any(merge_score > 0):
            if gain_mode == 'local':
                # switch to global optimization after local is finished
                gain_mode = 'global'
                continue
            break

        # disjoint pairs by descending score; idx_sort indexes the candidate lists above
        rows, cols, merge_score, idx_sort = get_disjoint_pairs(rows, cols, merge_score)

        if gain_mode == 'local':
            # evaluate the selected merges globally to keep the data-set BIC up to date
            new_kernels = KernelParameters(kernels.get_dim())
            p_merged_tot = np.zeros_like(tot_prob_init)
            p_separate_tot = np.zeros_like(tot_prob_init)

            for r, c in zip(rows, cols):
                old_kerns = _pair_kernels(kernels, r, c)
                new_kern = _merged_kernel(old_kerns, keep_wgt=True)

                p_merged_tot = p_merged_tot + np.sum(get_kernel_prob(X, new_kern), axis=1, keepdims=True)
                p_separate_tot = p_separate_tot + np.sum(get_kernel_prob(X, old_kerns), axis=1, keepdims=True)

                new_kernels.concatenate_parameters(new_kern)

        else:
            # reuse the results from above, reordered to match rows/cols
            selected = KernelParameters(kernels.get_dim())
            selected.add_kernels(*new_kernels.get_kernels(idx_sort))
            new_kernels = selected

            p_merged_tot = np.sum([p_merged[i] for i in idx_sort], axis=0)
            p_separate_tot = np.sum([p_separate[i] for i in idx_sort], axis=0)

        # update the mixture density and the bic after all merging events
        tot_prob_init = tot_prob_init - p_separate_tot + p_merged_tot
        bic_init = get_bic(tot_prob_init, n_kernels=n_kernels - len(rows))

        if np.isinf(bic_init):
            break

        # relabel the points; every kernel index is appended so that kernels without
        # assigned points are still counted in the mapping
        all_ids = np.arange(n_kernels)
        kernel_assign = merge_kernel_assignment(np.concatenate([kernel_assign, all_ids]),
                                                rows, cols)[:len(kernel_assign)]

        # remove the merged kernels and append their replacements in pair order
        kernels.delete_kernels(np.concatenate([rows, cols], axis=0))
        kernels.concatenate_parameters(new_kernels)

    return kernels, kernel_assign



In [None]:
def cut_chunks(X: np.ndarray,
               n_chunks: int):
    '''
    Cut X into n_chunks chunks based on its agglomerative (ward) tree

    Parameters
    ----------
    X: np.ndarray
        An array of observations. Must be of shape (n_samples, n_dimensions).

    n_chunks: int
        The number of chunks that the data should be cut into. Must be 1 or lie in
        [1, n_samples - 1]. Ties in the merge heights can yield fewer chunks.

    Returns
    --------
    chunk_labs: np.ndarray
        An integer array of shape (n_samples,) that contains the chunk assignment of each datapoint

    Raises
    ------
    ValueError
        If n_chunks is outside the valid range.
    '''

    if n_chunks == 1:
        return np.zeros(len(X), dtype=int)

    # the tree has n_samples - 1 merges, so at most n_samples - 1 chunks can be cut from it
    if n_chunks < 1 or n_chunks >= len(X):
        raise ValueError(f'n_chunks must be between 1 and {len(X) - 1}, was {n_chunks}')

    # cutting at the height of merge number (n_samples - n_chunks) leaves n_chunks clusters
    link_tree = ward(X)
    chunk_labs = fcluster(link_tree, link_tree[-n_chunks, 2], "distance") - 1

    return chunk_labs



def run_fault_reconstruction(X: np.ndarray,
                       min_sz_cluster: int = 4,
                       n_chunks: int = 1,
                       gain_mode: str = 'global'
                       ):
    '''
    Run the fault reconstruction algorithm (Kamer 2020)

    Every chunk is clustered independently. The steps are: capacity clusters, Gaussian kernel
    fit, EM reassignment, kernel merging and a second EM reassignment. Afterwards the kernels of
    all chunks are merged once more on the full data set.

    Parameters
    -----------
    X: np.ndarray
        The data to compute the cluster assignment for as an array of size (n_samples x n_dimensions)

    min_sz_cluster: int
        The threshold on the cluster size for a cluster to be considered valid, default = 4

    n_chunks: int
        Number of chunks that the data is cut into before performing the algorithm, default = 1

    gain_mode: str
        'local' or 'global', passed to merge_clusters, default = 'global'

    Returns
    --------
    all_labels: np.ndarray
        The cluster assignment of every data point, aligned with the rows of X
    '''

    chunk_labs = cut_chunks(X, n_chunks)
    n_points = len(X)

    all_kernels = KernelParameters()
    all_labels = np.zeros(n_points, dtype=int)
    lab_offset = 0

    for chunk_id in range(n_chunks):

        msk = chunk_labs == chunk_id

        # ties in the linkage heights can leave a chunk id without points
        if not np.any(msk):
            continue

        X_chunk = X[msk]

        # get the capacity clusters based on the agglomerative tree
        capacity_labs = get_capacity_clusters(X_chunk, min_sz_cluster)

        # fit the kernels
        kernels = fit_gaussian_kernels(X_chunk, capacity_labs, min_sz_cluster)

        # assign points and update kernels with EM
        kernels, cluster_labs, kernel_prob = assign_to_kernel(X_chunk, kernels, min_sz_cluster)

        # run the kernel merging algorithm
        kernels, cluster_labs = merge_clusters(X_chunk, kernels, cluster_labs, kernel_prob, gain_mode)

        # reassign the points with EM
        kernels, cluster_labs, kernel_prob = assign_to_kernel(X_chunk, kernels, min_sz_cluster)

        # write the shifted labels back at the chunk's own positions so they stay aligned with X
        all_labels[msk] = cluster_labs + lab_offset
        lab_offset += kernels.get_n_kernels()

        # assign_to_kernel sets weights to cluster_size / chunk_size; rescale to cluster_size / n_points
        kernels.modify_weight(kernels.get('w') * np.count_nonzero(msk) / n_points)

        # persist the kernels of the current chunk
        all_kernels.concatenate_parameters(kernels)

    # merge the kernels of all chunks
    kernel_prob = get_kernel_prob(X, all_kernels)
    all_kernels, all_labels = merge_clusters(X, all_kernels, all_labels, kernel_prob, gain_mode)

    return all_labels
