In [1]:
#!/usr/bin/env python

import os
import heapq
import numpy as np
import hnswlib
import time
import matplotlib.pyplot as plt
from connected_components import group_indices
from collections import Counter
from multiprocessing import Pool

# Exponent tying neighbourhood size to data size: the first k-NN search uses
# k = TOTAL_CONTIGS ** KUPPA and at most TOTAL_CONTIGS ** (1 - KUPPA) density
# peaks are kept.
KUPPA = 0.4
# Number of contigs and width of the latent space. Placeholders that the
# __main__ section overwrites from latent.shape before clustering starts.
TOTAL_CONTIGS = 0
LATENT_DIMENSION = 0
# Exponent applied to contig lengths when they are summed into a density
# (length ** ETA). Placeholder, overwritten in the __main__ section.
ETA = 0
# Neighbour distance threshold. Placeholder, recomputed by get_neighbors().
DISTANCE_CUTOFF = 1
# Summed contig length a neighbourhood has to exceed (see get_neighbors()).
NN_SIZE_CUTOFF = 200000
# Worker count: leave two cores free but never drop below one worker, since a
# zero or negative count is not a usable thread/process count. Falls back to
# 8 when the core count cannot be determined.
NCPUS = (max(1, os.cpu_count() - 2) if os.cpu_count() is not None else 8)

def add_selfindex(data_indices, labels, distances):
    """Force every query to list itself as its first nearest neighbour.

    The approximate search may not return the query point in column 0.  For
    each such row the neighbours are shifted one column to the right (the
    farthest one is dropped) and the query's own index is written in front
    with a distance of 0.  The elapsed time is printed.

    Args:
        data_indices: index of the query point for each row of ``labels``.
        labels: 2-D array of neighbour indices, one row per query; modified
            in place.
        distances: 2-D array of neighbour distances matching ``labels``;
            modified in place.

    Returns:
        The same ``labels`` and ``distances`` arrays after the in-place fix.
    """
    started = time.time()
    # Rows whose first returned neighbour is not the query point itself.
    missing_rows = np.nonzero(labels.astype(int)[:, 0] != data_indices)[0]

    # The right-hand sides use advanced indexing and are therefore copies, so
    # the overlapping shift is safe.
    labels[missing_rows, 1:] = labels[missing_rows, :-1]
    # Insert the query's own data index, not its row position in this batch;
    # the two only coincide when the queries are exactly arange(n).
    labels[missing_rows, 0] = np.asarray(data_indices)[missing_rows]
    distances[missing_rows, 1:] = distances[missing_rows, :-1]
    distances[missing_rows, 0] = 0.0

    print(time.time()-started, 'self index adding')
    return labels, distances

def get_neighbors(latent, length):
    """Collect, for every contig, the neighbours inside a data-driven cutoff.

    An hnswlib index is built over ``latent`` and queried with a neighbourhood
    size that doubles until at least half of the contigs have neighbours whose
    summed length exceeds NN_SIZE_CUTOFF.  The module-level DISTANCE_CUTOFF is
    then set to the median distance at which that summed length is first
    exceeded, and each contig's neighbour list is cut at that distance.
    Contigs whose retrieved neighbours all lie inside the cutoff are queried
    again with a larger k, falling back to an exhaustive scan.

    Reads the module-level TOTAL_CONTIGS, LATENT_DIMENSION, KUPPA,
    NN_SIZE_CUTOFF and NCPUS; overwrites DISTANCE_CUTOFF and prints it.

    Args:
        latent: 2-D array with one row per contig (TOTAL_CONTIGS rows of
            LATENT_DIMENSION values).
        length: array of contig lengths, indexed by contig number.

    Returns:
        Tuple ``(valid_indices, labels, distances)``.  ``valid_indices`` is a
        list holding, per contig, an integer array of the neighbours within
        DISTANCE_CUTOFF.  ``labels`` and ``distances`` are the k-NN tables of
        the first search with the query contig forced into column 0.
    """
    global DISTANCE_CUTOFF

    data_indices = np.arange(TOTAL_CONTIGS)
    # The index cannot return more neighbours than it holds.
    k_nn = min(int(TOTAL_CONTIGS ** KUPPA), TOTAL_CONTIGS)

    p = hnswlib.Index(space = 'l2', dim = LATENT_DIMENSION)
    # M - max. number of out-edges
    # ef_construction - the size of the dynamic list for the nearest neighbors
    # and controls index_time/accuracy. Important during indexing time
    p.init_index(max_elements = TOTAL_CONTIGS, ef_construction = 200, M = 16)
    p.set_ef(k_nn) # ef must not be smaller than the k being queried
    p.set_num_threads(NCPUS)
    p.add_items(latent)

    # Grow k until at least half of the contigs have neighbourhoods whose
    # summed length exceeds NN_SIZE_CUTOFF.
    while True:
        labels, distances = p.knn_query(latent, k=k_nn)
        cumsum_length = np.cumsum(length[labels], axis=1)
        contigs_with_enough_neighbors = np.nonzero(cumsum_length[:,-1] > NN_SIZE_CUTOFF)[0]
        enough = len(contigs_with_enough_neighbors) / TOTAL_CONTIGS >= 0.5
        # Stop as well once every contig is already being returned: doubling
        # k further would ask the index for more items than it contains.
        if enough or k_nn >= TOTAL_CONTIGS:
            break
        k_nn = min(k_nn * 2, TOTAL_CONTIGS) # adjustable
        p.set_ef(k_nn) # adjustable

    labels, distances = add_selfindex(data_indices, labels, distances)
    # Column at which the summed neighbour length first exceeds the cutoff.
    # argmax yields 0 both when column 0 already exceeds it and when no
    # column does, so the zero rows are disambiguated below.
    dist_cutoff_indices = np.argmax(
            np.cumsum(length[labels], axis=1) > NN_SIZE_CUTOFF, axis=1)
    indices_tofind_dist = np.nonzero(dist_cutoff_indices)[0]
    # for long contigs dist_cutoff_indices would be zero. Hence, add manually
    ind_check_longcontig = np.nonzero(dist_cutoff_indices==0)[0]
    indices_to_add = np.nonzero(length[ind_check_longcontig]> NN_SIZE_CUTOFF)[0]
    indices_tofind_dist = np.concatenate((indices_tofind_dist, \
                                        ind_check_longcontig[indices_to_add]))

    # Find distance cutoff
    DISTANCE_CUTOFF = np.median(
        distances[indices_tofind_dist,dist_cutoff_indices[indices_tofind_dist]])
    print(DISTANCE_CUTOFF, np.sqrt(DISTANCE_CUTOFF), 'distance cutoff')

    valid_indices = [[] for _ in range(TOTAL_CONTIGS)]

    # True where the farthest retrieved neighbour lies beyond the cutoff,
    # i.e. every neighbour inside the cutoff has already been retrieved.
    flags = distances[:,-1] > DISTANCE_CUTOFF
    c_indices = np.where(flags==1)[0]

    def update_nnlists(c_indices, nn_lists, nn_dists):
        # Create a mask for distances less than DISTANCE_CUTOFF
        mask = nn_dists <= DISTANCE_CUTOFF
        # Apply the mask to the neighbour table
        filtered_inds = np.where(mask, nn_lists, -1).astype(int)  # Using -1 as a placeholder for invalid indices
        for num, i in enumerate(c_indices):
            valid_indices[i] = filtered_inds[num, filtered_inds[num] >= 0]

    update_nnlists(c_indices, labels[flags], distances[flags])

    k_nn_more = k_nn
    remaining_indices = np.where(flags==0)[0]

    # NOTE(review): unlike the first pass, the re-query results below are not
    # passed through add_selfindex, so column 0 is whatever hnswlib returned.
    while len(remaining_indices) > 0:
        # Never request more neighbours than the index contains.
        k_nn_more = min(k_nn_more * 2, TOTAL_CONTIGS)
        p.set_ef(k_nn_more)
        nn_lists, nn_dists = p.knn_query(latent[remaining_indices], k = k_nn_more)
        flags_check = nn_dists[:,-1] > DISTANCE_CUTOFF
        indices_toupdate = np.where(flags_check==1)[0]
        if k_nn_more < 10000 and indices_toupdate.size > 0:
            c_indices = remaining_indices[indices_toupdate]
            remaining_indices = remaining_indices[~flags_check]
            update_nnlists(c_indices, nn_lists[indices_toupdate], nn_dists[indices_toupdate])
        else:
            # stop if k_nn_more increases to impede knn_query and
            # compute distance with all points for remaining contigs.
            # One row at a time: a single broadcast over all remaining
            # contigs would allocate a
            # (len(remaining_indices), TOTAL_CONTIGS, LATENT_DIMENSION) array.
            for inds in remaining_indices:
                row_distance = np.sum((latent - latent[inds])**2, axis=1)
                nn_inds = np.argsort(row_distance)
                valid_indices[inds] = nn_inds[row_distance[nn_inds] <= DISTANCE_CUTOFF]
            break

    del p

    return valid_indices, labels, distances

def get_density_peaks(valid_indices, labels, distances, length):
    """Compute per-contig densities and pick the density peaks.

    The density of a contig is the sum of ``length ** ETA`` over its
    neighbours within the cutoff, divided by DISTANCE_CUTOFF.  A contig is a
    density peak when it has the highest density among its own neighbours,
    or when it has no neighbour but itself and is at least NN_SIZE_CUTOFF
    long.  At most ``int(TOTAL_CONTIGS ** (1 - KUPPA))`` peaks are returned,
    highest density first.

    Reads the module-level TOTAL_CONTIGS, ETA, DISTANCE_CUTOFF,
    NN_SIZE_CUTOFF and KUPPA, and prints the elapsed time.

    Args:
        valid_indices: per contig, the indices of the neighbours within the
            distance cutoff (the contig itself is expected first).
        labels: k-NN index table with the contig itself in column 0.
        distances: k-NN distance table matching ``labels``.
        length: array of contig lengths.

    Returns:
        Tuple ``(densities, density_peaks, graph, nearest)``.  ``graph`` maps
        each contig to a list of ``(neighbour, neighbour_density)`` pairs and
        ``nearest`` holds, per contig, the contig it is attached to (itself
        for a peak).
    """
    s = time.time()

    densities = np.array([
        np.sum(length[valid_indices[c]] ** ETA) / DISTANCE_CUTOFF
        for c in range(TOTAL_CONTIGS)])

    graph = {}
    nearest = np.arange(TOTAL_CONTIGS)
    density_peaks_flag = np.full(TOTAL_CONTIGS,-1)

    for c, valid_inds in enumerate(valid_indices):
        valid_inds = np.asarray(valid_inds, dtype=int)

        # Isolated contig: nothing but itself (or nothing at all) lies
        # within the cutoff.
        if valid_inds.size <= 1:
            # NN_SIZE_CUTOFF is a length threshold, so the contig's length
            # (not its density) decides whether it forms a cluster alone.
            if length[c] >= NN_SIZE_CUTOFF: # long genomic contig as a single cluster
                density_peaks_flag[c] = c
                graph[c] = []
                nearest[c] = c
                continue

            # Short isolated contig: look a little beyond the cutoff and
            # attach it to the densest point found there.
            # HyperParam 1.8
            nearest_points = np.nonzero(distances[c] <= DISTANCE_CUTOFF * 1.8)[0]
            if nearest_points.size > 1:
                candidates = labels[c][nearest_points].astype(int)
                best = int(candidates[np.argmax(densities[candidates])])
                # Same list-of-(neighbour, density) form as every other graph
                # entry, so that path searches can iterate it.
                graph[c] = [(best, densities[best])]
                nearest[c] = best
            else:
                graph[c] = []
                nearest[c] = labels[c][1]
            continue

        max_densityindex = valid_inds[np.argmax(densities[valid_inds])]
        # valid_inds all can have the same density but small. They become separate low-density peaks
        # TODO: have to solve to merge such low-density peaks with higher density peaks
        if c == max_densityindex:
            density_peaks_flag[c] = c
            nearest[c] = c
        else:
            nearest[c] = max_densityindex

        # Edges to every neighbour but the first entry, weighted by the
        # neighbour's density.
        graph[c] = list(zip(valid_inds[1:], densities[valid_inds[1:]]))

    print(time.time() - s, 'seconds for graph and densities calculation')
    density_peaks_inds = np.nonzero(density_peaks_flag>=0)[0]

    peak_count = min(int(TOTAL_CONTIGS ** (1-KUPPA)), len(density_peaks_inds))

    # Keep the peak_count peaks with the highest density.
    get_inds = np.argsort(densities[density_peaks_inds])[::-1][:peak_count]
    density_peaks = density_peaks_inds[get_inds]

    return densities, density_peaks, graph, nearest

def dijkstra_max_min_density(graph, start):
    """Return the best bottleneck density from ``start`` to every node.

    The value of a path is the smallest edge density along it; for each node
    the largest such value over all paths from ``start`` is reported.

    Args:
        graph: dict mapping each node to an iterable of
            ``(neighbor, density)`` pairs.  Every neighbor must itself be a
            key of ``graph``.
        start: node to search from; must be a key of ``graph``.

    Returns:
        dict mapping every node of ``graph`` to its best bottleneck density:
        ``inf`` for ``start`` itself and ``-inf`` for unreachable nodes.
    """
    max_min_densities = {node: float('-inf') for node in graph}
    max_min_densities[start] = float('inf')
    # heapq is a min-heap, so densities are stored negated to pop the
    # largest one first.
    priority_queue = [(-max_min_densities[start], start)]

    while priority_queue:
        negated_density, current_node = heapq.heappop(priority_queue)
        current_density = -negated_density

        # A better value for this node was found after this entry had been
        # pushed; expanding the stale entry cannot improve any neighbour.
        if current_density < max_min_densities[current_node]:
            continue

        for neighbor, density in graph[current_node]:
            # Bottleneck of the path that reaches neighbor via current_node
            new_density = min(current_density, density)
            if new_density > max_min_densities[neighbor]:
                max_min_densities[neighbor] = new_density
                heapq.heappush(priority_queue, (-new_density, neighbor))

    return max_min_densities

def find_connected_components(merge_links):
    """Label the connected components of a link list.

    Links are treated as undirected: when ``a`` lists ``b``, both end up in
    the same component whichever of them comes first.  ``merge_links`` is
    left unmodified.

    Args:
        merge_links: sequence in which entry ``k`` lists the node indices
            linked to node ``k`` (a node may list itself).

    Returns:
        Integer array of length ``len(merge_links)`` holding each node's
        component id; ids are numbered from 0 in order of each component's
        lowest node index.
    """
    K = len(merge_links)

    # Symmetric adjacency built on private sets, so the caller's lists are
    # never popped from or extended.
    adjacency = [set() for _ in range(K)]
    for k, links in enumerate(merge_links):
        for l in links:
            l = int(l)
            adjacency[k].add(l)
            adjacency[l].add(k)

    merge_sets = np.full(K, -1, dtype=int)
    merge_curr = 0
    for k in range(K):
        if merge_sets[k] >= 0:
            continue

        # Depth-first flood fill from the first unlabelled node.
        merge_sets[k] = merge_curr
        candidates = list(adjacency[k])
        while candidates:
            l = candidates.pop()
            if merge_sets[l] < 0:
                merge_sets[l] = merge_curr
                candidates.extend(adjacency[l])

        merge_curr += 1

    return merge_sets

def remove_outliers_peaks(data, m=2.):
    """Return the indices of the entries of ``data`` that are not outliers.

    Each entry's absolute deviation from the median is scaled by the median
    of those deviations (or by 1 when that median is zero); entries whose
    scaled deviation is below ``m`` are kept.
    """
    deviation = np.abs(data - np.median(data))
    scale = np.median(deviation)
    if not scale:
        # avoid dividing by zero when at least half of the values coincide
        scale = 1.
    scaled = deviation / scale
    return np.nonzero(scaled < m)[0]  # indices of the points that are kept

def density_links(args):
    """Find the density peaks reachable from one peak and how separable it is.

    Takes a single tuple so that it can be mapped over a list of argument
    tuples.  Prints the peak and its position.

    Args:
        args: tuple ``(i, peak, graph, density_peaks, densities)`` where
            ``i`` is the position of ``peak`` within ``density_peaks``,
            ``graph`` is the neighbour graph accepted by
            dijkstra_max_min_density and ``densities`` holds the per-contig
            densities.

    Returns:
        Tuple ``(i, separability_index, merge_link)``.  ``merge_link`` is
        ``[i]`` followed by the positions (within ``density_peaks``) of the
        other peaks reachable from ``peak``, or an empty list when none is
        reachable, in which case ``separability_index`` is 1.
    """
    i, peak, graph, density_peaks, densities = args
    print(peak, i, 'peak and i')
    max_min_densities = dijkstra_max_min_density(graph, peak)

    # Build the lookup once: a membership test against the array itself is a
    # linear scan for every node of the graph.
    peak_lookup = set(np.asarray(density_peaks).tolist())
    unreachable = float('-inf')
    higherdensity_links = {
        k: v for k, v in max_min_densities.items()
        if k in peak_lookup and k != peak and v != unreachable
    }

    if not higherdensity_links:
        return i, 1, []

    # The reachable peak with the best bottleneck density sets the separability.
    max_key = max(higherdensity_links, key=higherdensity_links.get)
    separability_index = 1 - (higherdensity_links[max_key] / densities[peak])
    merge_peakinds = np.where(np.isin(density_peaks, list(higherdensity_links.keys())))[0]
    merge_link = [i] + list(merge_peakinds)
    return i, separability_index, merge_link


def cluster(latent, contig_length):
    """Run the neighbour search and density-peak detection on the contigs.

    Only these first two stages are active: the peak-linking and merging
    steps below are commented out, so the function currently returns None.

    Args:
        latent: latent representation of the contigs, passed on to
            get_neighbors.
        contig_length: contig lengths, passed on to get_neighbors and
            get_density_peaks.
    """

    neighbor_sets, knn_labels, knn_distances = get_neighbors(latent, contig_length)
    # print(neighbor_sets, 'labels and distances')
    densities, density_peaks, graph, nearest = get_density_peaks(
        neighbor_sets, knn_labels, knn_distances, contig_length)
    # print(len(density_peaks), 'total density peaks')

    # the k-NN tables are not needed beyond this point
    del knn_labels, knn_distances

    # one separability value per peak, initialised to 1
    separability_indices = np.full(len(density_peaks), 1.0, dtype=np.float32)

    # peak_neighbors = [valid_indices[i].tolist() for i in density_peaks]
    # print(graph, 'graph')

    # connected_peakcomponents = group_indices(peak_neighbors)

    # # from connected_components import select_subgraph
    # # subgraph = select_subgraph(graph, connected_peakcomponents[0][1])
    # # print(subgraph, 'subgraph')
    # subgraph1 = {node: [adj_node for adj_node in graph[node] if adj_node in connected_peakcomponents[0][1]] for node in connected_peakcomponents[0][1]}
    # print(subgraph1, 'subgraph1')
    # merge_links = []
    # for i, peak in enumerate(density_peaks):
    #     max_min_densities = dijkstra_max_min_density(graph, peak)

    #     higherdensity_links = {k:v for k, v in max_min_densities.items() if (k in density_peaks) and (k != peak) and (v != float('-inf'))}
    #     # higherdensity_links = {k:v for k, v in higherdensity_links.items() if v > densities[peak]} # this doesn't work when v and densities[peak] is same value
    #     if higherdensity_links:
    #         # print(peak, higherdensity_links)
    #         max_key = max(higherdensity_links, key=higherdensity_links.get)
    #         separability_indices[i] = 1 - (higherdensity_links[max_key] / densities[peak])
    #         merge_peakinds = np.where(np.isin(density_peaks, list(higherdensity_links.keys())))[0]
    #         merge_links.append([i] + list(merge_peakinds))
    #     else:
    #         merge_links.append([])

    # # args_list = [(i, peak, graph, density_peaks, densities) for i, peak in enumerate(density_peaks)]
    # # merge_links = []

    # # with Pool(processes=NCPUS) as pool:
    # #     results = pool.map(density_links, args_list)

    # # for i, separability_index, merge_link in results:
    # #     separability_indices[i] = separability_index
    # #     merge_links.append(merge_link)

    # # # not useful for deciding number of clusters and assignment as we have merging of density peaks
    # plt.figure(figsize=(16,12))
    # plt.scatter(densities[density_peaks], separability_indices)
    # plt.xlabel('density')
    # plt.ylabel('separability index')

    # separability_indices = np.sort(separability_indices)[::-1]

    # cluster_centercounts = np.argmax(np.abs(np.diff(separability_indices))) + 1
    # if cluster_centercounts == 1:
    #     cluster_centercounts = len(density_peaks)

    # # # cluster_centers = density_peaks[densities[density_peaks] > int(TOTAL_CONTIGS ** KUPPA)] # density_peaks[:cluster_centercounts]
    # # ##################################

    # merge_sets = find_connected_components(merge_links)

    # return densities, density_peaks, merge_sets, nearest, cluster_centercounts


def assign_points(density_peaks, merge_sets, nearest, densities):
    """Assign points to density peaks and merge peaks sharing a merge set.

    Every point follows its ``nearest`` pointer until the chain stops
    changing.  Points whose chain ends at a density peak form that peak's
    component; points whose chain ends anywhere else are not part of any
    component.  Components of peaks with the same merge-set id are then
    concatenated into one cluster.

    Args:
        density_peaks: indices of the density peaks.
        merge_sets: merge-set id of each peak, aligned with ``density_peaks``.
        nearest: per point, the index of the point it is attached to.  The
            caller's array is left unmodified.
        densities: per-point densities (currently unused).

    Returns:
        Tuple ``(clusters, components)``.  ``components`` holds one index
        array per peak, in ``density_peaks`` order; ``clusters`` holds one
        list of point indices per merge set, in ascending merge-set id.

    Raises:
        RuntimeWarning: if a density peak ends up without any point.
    """
    density_peaks = np.asarray(density_peaks, dtype=int)
    merge_sets = np.asarray(merge_sets)
    # Work on a copy so that the caller's pointer array is not overwritten.
    nearest = np.array(nearest)
    # Peaks point at themselves so that every chain can terminate on one.
    nearest[density_peaks] = density_peaks

    # Pointer jumping: follow the chain four steps at a time until no
    # pointer changes any more.
    nearest_prev = np.full(nearest.size, -1, dtype=int)
    while (nearest != nearest_prev).any():
        nearest_prev = nearest
        nearest = nearest[nearest[nearest[nearest]]]

    components = []
    for peak in density_peaks:
        members = np.nonzero(nearest == peak)[0]
        if members.size == 0:
            raise RuntimeWarning("no clusters in cluster_centers[k] is assigned to nearest[k]")
        components.append(members)

    clusters = []
    # np.unique yields the merge-set ids in ascending order, which keeps the
    # cluster order deterministic.
    for set_id in np.unique(merge_sets):
        merged = []
        for j in np.nonzero(merge_sets == set_id)[0]:
            merged.extend(components[j])
        clusters.append(merged)

    return clusters, components

if __name__ == "__main__":

    # Directory holding the input arrays for this run.
    data_dir = '/Users/yazhini/Documents/work/'

    latent = np.load(os.path.join(data_dir, 'latent_mu.npy'))
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
    # archives from a trusted source.
    length_archive = np.load(os.path.join(data_dir, 'contigs_2klength.npz'), allow_pickle=True)
    contig_length = length_archive['arr_0']
    # contig_names = np.load(data_dir + 'contigs_2knames.npz',allow_pickle=True)['arr_0']
    otuids = np.loadtxt(os.path.join(data_dir, 'otuids'), dtype='object')

    # Fill in the module-level settings that the clustering functions read.
    TOTAL_CONTIGS, LATENT_DIMENSION = latent.shape
    ETA = 0.4 # 1 / LATENT_DIMENSION

    print(latent.shape)
    cluster(latent, contig_length)
    # densities, density_peaks, merge_sets, nearest, cluster_centercounts = cluster(latent, contig_length)
    # clusters, components = assign_points(density_peaks, merge_sets, nearest, densities)
    # labels = np.zeros(TOTAL_CONTIGS,dtype=int)

    # counter = 1
    # for i in range(len(clusters)):
    #     labels[clusters[i]] = counter
    #     counter += 1

    # for j in range(len(labels)):
    #     print(labels[j], otuids[j])

    # import seaborn as sns
    # import plotly.express as px
    # fig = px.scatter(x=latent[:,8], y=latent[:,4])
    # fig.update_traces(marker=dict(size=1))
    # fig.show()



(159956, 32)
0.014136075973510742 self index adding
19052.436 138.03056 distance cutoff
4.717833042144775 seconds for graph and densities calculation
