In [None]:
#| default_exp clustering.cluster

In [None]:
#| export
import torch.nn.functional as F, scipy.sparse as sp, numpy as np, functools, operator, torch, time, sys, gc, os
from sklearn.preprocessing import normalize
from multiprocessing import Pool
from torch.utils.data import Sampler
from typing import Optional, List, Union, Any

from xcai.core import *
from xcai.clustering.fast_cluster import balanced_cluster, next_power_of_two

from fastcore.dispatch import *
from fastcore.basics import *

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

## `BalancedClusters`: CLUSTERING

In [None]:
#| export
def b_kmeans_dense_multi(fts_lbl, index, tol=1e-4, *args, **kwargs):
    """Balanced 2-means split of items described by two dense views.

    Args:
        fts_lbl: numpy array of shape (n, 2, ...); view 0 (`fts_lbl[:, 0]`) seeds the
            initial split, and both views contribute to the similarity during refinement.
            Each view is flattened to (n, d) and row-L2-normalized.
        index: array of item ids aligned with the rows of `fts_lbl`.
        tol: stop once the objective improves by less than this.
        *args, **kwargs: ignored; accepted so every splitter shares one call signature
            (`cluster` passes e.g. `use_cuda=`).

    Returns:
        `[index]` for fewer than two rows, otherwise two index arrays whose sizes differ
        by at most one.
    """
    # Guard before touching the data: nothing to split.
    if fts_lbl.shape[0] <= 1:
        return [index]
    # Reshape instead of `np.squeeze`: squeeze also dropped the row axis (n == 1) or the
    # feature axis (d == 1), handing `normalize` a 1-D array it rejects.
    n_rows = fts_lbl.shape[0]
    lbl_cent = normalize(fts_lbl[:, 0].reshape(n_rows, -1))
    lbl_fts = normalize(fts_lbl[:, 1].reshape(n_rows, -1))
    cluster = np.random.randint(low=0, high=n_rows, size=2)
    while cluster[0] == cluster[1]:
        cluster = np.random.randint(low=0, high=n_rows, size=2)
    _sim = np.dot(lbl_cent, lbl_cent[cluster].T)
    old_sim, new_sim = -1000000, -2
    while new_sim - old_sim >= tol:
        # Rank items by preference for centroid 1 and cut the ranking in half -> balanced.
        c_lbs = np.array_split(np.argsort(_sim[:, 1] - _sim[:, 0]), 2)
        cent_centroids = normalize(np.vstack([np.mean(lbl_cent[x, :], axis=0) for x in c_lbs]))
        fts_centroids = normalize(np.vstack([np.mean(lbl_fts[x, :], axis=0) for x in c_lbs]))
        # Similarity is the sum over both views.
        _sim = np.dot(lbl_cent, cent_centroids.T) + np.dot(lbl_fts, fts_centroids.T)
        old_sim, new_sim = new_sim, np.sum(_sim[c_lbs[0], 0]) + np.sum(_sim[c_lbs[1], 1])
    return [index[x] for x in c_lbs]


def b_kmeans_dense(labels_features, index, tol=1e-4, *args, **kwargs):
    """Balanced 2-means split of the rows of a dense numpy matrix.

    Rows are compared by dot product. Centroids are row-L2-normalized, and the input
    rows are presumably normalized by the caller (`cluster` runs `_normalize` first).

    Args:
        labels_features: 2-D numpy array, one row per item.
        index: array of item ids aligned with the rows of `labels_features`.
        tol: stop once the objective improves by less than this.
        *args, **kwargs: ignored; keep the splitter call signature uniform.

    Returns:
        `[index]` for a single row, otherwise two index arrays whose sizes differ by at most one.
    """
    n_points = labels_features.shape[0]
    if n_points == 1:
        return [index]

    # Pick two distinct seed rows as the initial centroids.
    seeds = np.random.randint(low=0, high=n_points, size=2)
    while seeds[0] == seeds[1]:
        seeds = np.random.randint(low=0, high=n_points, size=2)
    sim = np.dot(labels_features, labels_features[seeds].T)

    prev_score, score = -1000000, -2
    while score - prev_score >= tol:
        # Sort by preference for seed 1 over seed 0, then cut in half -> balanced halves.
        halves = np.array_split(np.argsort(sim[:, 1] - sim[:, 0]), 2)
        centroids = normalize(np.vstack([labels_features[h, :].mean(axis=0) for h in halves]))
        sim = np.dot(labels_features, centroids.T)
        prev_score, score = score, sim[halves[0], 0].sum() + sim[halves[1], 1].sum()
    return [index[h] for h in halves]


def b_kmeans_sparse(labels_features, index, tol=1e-4, *args, **kwargs):
    """Balanced 2-means split of the rows of a scipy sparse matrix.

    Rows are L2-normalized with `normalize`. Centroids are kept as dense numpy
    ndarrays (never `np.matrix`, which scikit-learn's validation no longer accepts).

    Args:
        labels_features: scipy sparse matrix, one row per item.
        index: array of item ids aligned with the rows.
        tol: stop once the objective improves by less than this.
        *args, **kwargs: ignored; keep the splitter call signature uniform.

    Returns:
        `[index]` for a single row, otherwise two index arrays whose sizes differ by at most one.
    """
    def _sdist(XA, XB):
        # sparse @ dense.T -> dense (n, 2) similarity; asarray guards against matrix results
        return np.asarray(XA.dot(XB.transpose()))

    def _row_mean(rows):
        # sparse `.mean(axis=0)` returns np.matrix (or 1-D for sparse arrays): force a (1, d) ndarray
        return np.asarray(rows.mean(axis=0)).reshape(1, -1)

    labels_features = normalize(labels_features)
    if labels_features.shape[0] == 1:
        return [index]
    cluster = np.random.randint(low=0, high=labels_features.shape[0], size=(2))
    while cluster[0] == cluster[1]:
        cluster = np.random.randint(
            low=0, high=labels_features.shape[0], size=(2))
    # `.toarray()` rather than `.todense()`: the latter yields np.matrix
    _centeroids = normalize(labels_features[cluster].toarray())
    _sim = _sdist(labels_features, _centeroids)
    old_sim, new_sim = -1000000, -2
    while new_sim - old_sim >= tol:
        c_lbs = np.array_split(np.argsort(_sim[:, 1]-_sim[:, 0]), 2)
        _centeroids = normalize(np.vstack([
            _row_mean(labels_features[x, :]) for x in c_lbs]))
        _sim = _sdist(labels_features, _centeroids)
        old_sim, new_sim = new_sim, np.sum([
            np.sum(_sim[c_lbs[0], 0]), np.sum(_sim[c_lbs[1], 1])])
    return [index[x] for x in c_lbs]


def b_kmeans_dense_gpu(labels_features, index, tol=1e-4, use_cuda=False):
    """Balanced 2-means split of a dense torch feature matrix (torch counterpart of `b_kmeans_dense`).

    Args:
        labels_features: 2-D torch tensor, one row per item. Rows are compared by dot
            product, so they are presumably L2-normalized by the caller.
        index: array of item ids aligned with the rows of `labels_features`.
        tol: stop once the objective improves by less than this.
        use_cuda: move `labels_features` to the current CUDA device before splitting.

    Returns:
        `[index]` for fewer than two rows, otherwise two index arrays whose sizes differ
        by at most one.
    """
    # Nothing to split; returning first also skips a pointless host-to-device copy.
    if labels_features.shape[0] <= 1:
        return [index]
    if use_cuda:
        labels_features = labels_features.cuda()
    n_rows = labels_features.shape[0]
    cluster = np.random.randint(low=0, high=n_rows, size=2)
    while cluster[0] == cluster[1]:
        cluster = np.random.randint(low=0, high=n_rows, size=2)
    _centeroids = labels_features[cluster]
    _similarity = torch.mm(labels_features, _centeroids.T)
    old_sim, new_sim = -1000000, -2
    while new_sim - old_sim >= tol:
        sim_diff = _similarity[:, 1] - _similarity[:, 0]
        # Rank by preference for centroid 1 and cut in half -> balanced assignment.
        clustered_lbs = np.array_split(np.argsort(sim_diff.cpu().numpy()), 2)
        c_l = torch.mean(labels_features[clustered_lbs[0], :], dim=0)
        c_r = torch.mean(labels_features[clustered_lbs[1], :], dim=0)
        _centeroids = F.normalize(torch.stack([c_l, c_r], dim=0))
        _similarity = torch.mm(labels_features, _centeroids.T)
        s_l = torch.sum(_similarity[clustered_lbs[0], 0]).item()
        s_r = torch.sum(_similarity[clustered_lbs[1], 1]).item()
        old_sim, new_sim = new_sim, s_l + s_r
    # Any temporary CUDA copy is released when this frame returns. Copying it back to
    # the host and forcing gc.collect() per call only added work.
    return [index[x] for x in clustered_lbs]


In [None]:
#| export
def get_functions(mat):
    """Return the balanced 2-means splitter matching the type of `mat`.

    Dispatch:
        - torch tensor -> `b_kmeans_dense_gpu`
        - 3-D numpy array (two stacked views per row) -> `b_kmeans_dense_multi`
        - 2-D numpy array -> `b_kmeans_dense`
        - scipy sparse matrix -> `b_kmeans_sparse`

    Raises:
        TypeError: for any other type, or a numpy array that is not 2-D/3-D.
    """
    if torch.is_tensor(mat):
        print("Using GPU for clustering")
        return b_kmeans_dense_gpu
    if isinstance(mat, np.ndarray):
        if len(mat.shape) == 3:
            print("Using dense kmeans++ for multi-view")
            return b_kmeans_dense_multi
        elif len(mat.shape) == 2:
            print("Using dense kmeans++")
            return b_kmeans_dense
    elif sp.issparse(mat):
        print("Using sparse kmeans++")
        return b_kmeans_sparse
    print("dtype not understood!!")
    # Raise instead of `exit(0)`: exiting killed the interpreter/kernel and reported *success*.
    ndim = getattr(mat, "ndim", None)
    raise TypeError(f"dtype not understood!! unsupported input {type(mat)} (ndim={ndim})")


def _normalize(mat):
    if torch.is_tensor(mat):
        return mat
    elif isinstance(mat, np.ndarray) or sp.issparse(mat):
        return normalize(mat)
    else:
        raise TypeError(f"{type(mat)} is not supported")



In [None]:
#| export
def _to_splitter_input(labels, force_gpu):
    """Convert (already normalized) `labels` into the container the splitters expect.

    Returns a float32 torch tensor when `force_gpu` is set. Otherwise returns a float32
    numpy array, or keeps a scipy sparse matrix sparse (cast to float32).
    """
    if force_gpu:
        if sp.issparse(labels):
            labels = labels.toarray()
        if not torch.is_tensor(labels):
            labels = torch.from_numpy(np.asarray(labels))
        return labels.type(torch.FloatTensor)
    if torch.is_tensor(labels):
        # detach/float first: `.numpy()` refuses grad-tracking and bfloat16 tensors
        return labels.detach().float().cpu().numpy()
    if sp.issparse(labels):
        return labels.astype(np.float32)
    return np.array(labels, dtype=np.float32)


def cluster(labels, max_leaf_size=None, min_splits=16, num_workers=4,
            return_smat=False, num_clusters=None, force_gpu=False):
    """Recursively bisect the rows of `labels` into balanced groups with 2-means.

    Args:
        labels: torch tensor, 2-D/3-D numpy array or scipy sparse matrix, one row per item.
            numpy/sparse input is row-normalized by `_normalize`; tensors pass through it as-is.
        max_leaf_size: target leaf size. Used to derive the number of groups,
            2**ceil(log2(n / max_leaf_size)), when `num_clusters` is None.
        min_splits: levels with fewer groups than this are split in-process. Deeper levels
            use a `Pool(num_workers)` (numpy/sparse) or per-group CUDA (torch).
        num_workers: worker processes for the pooled levels.
        return_smat: return an (n, n_groups) CSR indicator matrix instead of a list.
        num_clusters: explicit number of groups to reach (at least 1).
        force_gpu: split with the torch splitter (`b_kmeans_dense_gpu`).

    Returns:
        List of index arrays (one per group), or a CSR matrix when `return_smat`. Splitting
        stops early if every group is a single row.
    """
    n_rows = labels.shape[0]
    num_nodes = num_clusters
    if num_nodes is None:
        num_nodes = 2**np.ceil(np.log2(n_rows/max_leaf_size))
    # fewer rows than one leaf would give 0 groups (and a division by zero below)
    num_nodes = max(int(num_nodes), 1)
    group = [np.arange(n_rows)]
    labels = _to_splitter_input(_normalize(labels), force_gpu)
    splitter = get_functions(labels)
    min_single_thread_split = min(min_splits, num_nodes)
    if min_single_thread_split < 1 and torch.is_tensor(labels):
        labels = labels.cuda()
    print(f"Max leaf size {max_leaf_size}")
    print(f"Total number of group are {num_nodes}")
    print(f"Average leaf size is {n_rows/num_nodes}")
    start = time.time()

    def splits(flag, group):
        if torch.is_tensor(labels):
            # torch splitter runs in-process; slices go to CUDA only on the deeper levels
            return map(lambda x: splitter(labels[x], x, use_cuda=not flag), group)
        if flag:
            return map(lambda x: splitter(labels[x], x), group)
        with Pool(num_workers) as p:
            # Pass only (features, index): a third positional argument would bind to `tol`
            # (False -> tol=0, and the convergence loop would never exit).
            return p.starmap(splitter, [(labels[x], x) for x in group])

    def print_stats(group, end="\n", file=sys.stdout):
        string = f"Total groups {len(group)}"
        string += f", Avg. group size {np.mean(list(map(len, group)))}"
        string += f", Total time {time.time()-start} sec."
        print(string, end=end, file=file)

    while len(group) < num_nodes:
        print_stats(group, "\r", sys.stderr)
        flag = len(group) < min_single_thread_split
        new_group = functools.reduce(operator.iconcat, splits(flag, group), [])
        if len(new_group) == len(group):
            break  # every group is a single row: no further split is possible
        group = new_group
    print_stats(group)
    if return_smat:
        rows = np.concatenate(group)
        cols = np.repeat(np.arange(len(group)), [len(g) for g in group])
        # at least `num_nodes` columns, but never fewer than the groups actually produced
        shape = (n_rows, max(num_nodes, len(group)))
        group = sp.csr_matrix((np.ones(len(rows)), (rows, cols)), shape=shape)
    return group


def _tree_depth(n:int, min_leaf_sz) -> int:
    "Smallest depth d >= 0 with 2**d >= n / min_leaf_sz, i.e. ceil(log2(n / min_leaf_sz)) clamped at 0."
    n_leaves = int(-(-n // min_leaf_sz))  # ceil(n / min_leaf_sz) without float-log rounding error
    return max(n_leaves - 1, 0).bit_length()


def partial_cluster(
    embs_bank: torch.Tensor,
    min_leaf_sz: int,
    num_random_clusters: int,
    clustering_devices: Optional[List]=None,
    ):
    """Balanced hierarchical clustering of `embs_bank` via `balanced_cluster`.

    Args:
        embs_bank: torch tensor, one row per item (any device/dtype). A CPU fp16 copy is clustered.
        min_leaf_sz: target leaf size. The tree depth is ceil(log2(n / min_leaf_sz)), at least 0.
        num_random_clusters: forwarded to `balanced_cluster`; -1 means
            `next_power_of_two(len(clustering_devices))`.
        clustering_devices: device ids. Default: one per entry of `CUDA_VISIBLE_DEVICES` if set,
            else `range(torch.cuda.device_count())`.

    Returns:
        Whatever `balanced_cluster` returns (in the examples, a list of index arrays, one per leaf).

    Raises:
        ValueError: if `embs_bank` is not a `torch.Tensor`.
    """
    if not isinstance(embs_bank, torch.Tensor): raise ValueError('`embs_bank` should be `torch.Tensor`')
    # One private CPU fp16 copy (the clone semantics of before). This also replaces the legacy
    # `torch.HalfTensor(...)` constructor, which rejects tensors of another device/type (e.g. CUDA).
    embs = embs_bank.detach().to(device='cpu', dtype=torch.float16, copy=True)
    tree_depth = _tree_depth(embs.shape[0], min_leaf_sz)
    print(f"Updating clusters with size {min_leaf_sz}")
    print(f"Tree depth = {tree_depth}")

    if clustering_devices is None:
        clustering_devices = (
            np.arange(len(os.getenv("CUDA_VISIBLE_DEVICES").split(',')))
            if os.getenv("CUDA_VISIBLE_DEVICES") is not None else
            np.arange(torch.cuda.device_count())
        )

    num_random_clusters = (
        num_random_clusters
        if num_random_clusters != -1
        else next_power_of_two(len(clustering_devices))
    )
    if num_random_clusters < len(clustering_devices):
        print("num_random_clusters provided is less than number of clustering devices which is not optimal")

    clusters = balanced_cluster(embs,
                                tree_depth,
                                clustering_devices,
                                num_random_clusters,
                                True)

    del embs
    gc.collect()

    return clusters


In [None]:
#| export
class BalancedClusters:
    "Namespace for balanced hierarchical clustering of an embedding matrix."

    @staticmethod
    def proc(x:torch.Tensor, min_cluster_sz:int, clustering_devices:Optional[List]=None, verbose:Optional[bool]=True):
        """Cluster the rows of `x` into balanced leaves of roughly `min_cluster_sz` items.

        Args:
            x: embedding tensor, one row per item.
            min_cluster_sz: target leaf size.
            clustering_devices: device ids forwarded to `partial_cluster` (None: its default).
            verbose: accepted for API compatibility; currently ignored.

        Returns:
            The result of `partial_cluster`.
        """
        # -1 lets `partial_cluster` derive the number of random clusters from the device count.
        auto_random_clusters = -1
        return partial_cluster(x, min_cluster_sz, auto_random_clusters, clustering_devices)
        

### Example

In [None]:
# Seed the global torch and numpy RNGs so the sample data (and any RNG-driven splitting)
# is reproducible under Restart & Run All.
torch.manual_seed(0)
np.random.seed(0)
x = torch.randn(10, 10)

In [None]:
%time clusters = BalancedClusters.proc(x, 3, use_fast_clustering=True)

Updating clusters with size 3
Tree depth = 2
doing random split
lengths: [5, 5]
remaining levels for GPU split=1
==> gpu splitting random clusters 0 to 2
 rank=1 => Total clusters 2	Avg. Cluster size                 2.50	Time to split nodes on this level 1.11 sec
 rank=0 => Total clusters 2	Avg. Cluster size                 2.50	Time to split nodes on this level 1.10 sec

CPU times: user 197 ms, sys: 128 ms, total: 325 ms
Wall time: 8.49 s


In [None]:
# Leaf clusters: one array of row indices of `x` per cluster
clusters

[array([0, 9, 2]), array([1, 7]), array([3, 5, 8]), array([4, 6])]

In [None]:
#%time clusters = BalancedClusters.proc(x, 5, use_fast_clustering=False)

## `ClusterGroupedSampler`: CLUSTER BASED SAMPLING

In [None]:
#| export
class ClusterGroupedSampler(Sampler):
    """Sampler over `range(n)` that keeps members of the same cluster together.

    Without a cluster assignment it yields a random permutation of `range(n)`. With one,
    clusters are visited in random order, and each cluster's members are emitted
    contiguously in shuffled order. Consecutive batches therefore mostly hold items
    from the same cluster.

    Args:
        n: number of indices to yield. Must equal the total size of `cluster` when one is set.
        cluster: optional list of index collections (numpy arrays, torch tensors or lists).
        generator: optional `torch.Generator` driving all shuffling (global torch RNG when None).
    """

    def __init__(self, n:int, cluster:Optional[List]=None, generator:Optional[Any]=None):
        store_attr('n,cluster,generator')

    def __len__(self):
        return self.n

    def set_cluster(self, cluster): self.cluster = cluster

    def _perm(self, k:int) -> torch.Tensor:
        "Random permutation of `range(k)` drawn from `self.generator`."
        return torch.randperm(k, generator=self.generator)

    def __iter__(self):
        if self.cluster is None: return iter(self._perm(self.n).tolist())
        csz = sum([len(o) for o in self.cluster])
        if len(self) != csz: raise ValueError(f'`n`({len(self)}) should be equal to total elements in `cluster`({csz})')

        # Shuffle the cluster order, then the members within each cluster. Handle each
        # cluster by its own type, so mixed/list clusters and an empty cluster list work.
        indices = []
        for i in self._perm(len(self.cluster)).tolist():
            members = self.cluster[i]
            perm = self._perm(len(members))
            if isinstance(members, torch.Tensor): indices.extend(members[perm].tolist())
            else: indices.extend(np.asarray(members)[perm.numpy()].tolist())
        return iter(indices)
        

### Example

In [None]:
from torch.utils.data import DataLoader

In [None]:
# 16 random 3-d points to cluster and then batch with `ClusterGroupedSampler`
x = torch.randn(16, 3)

In [None]:
# Leaves of ~4 items (tree depth ceil(log2(16/4)) = 2, i.e. 4 clusters), clustering on device 0 only.
# NOTE(review): this rebinds `cluster`, shadowing the `cluster()` function defined earlier in this notebook.
cluster = BalancedClusters.proc(x, 4, [0])

Updating clusters with size 4
Tree depth = 2
doing cpu split
remaining levels for GPU split=2
==> gpu splitting random clusters 0 to 1
 rank=0 => Total clusters 2	Avg. Cluster size                 8.00	Time to split nodes on this level 0.78 sec
 rank=0 => Total clusters 4	Avg. Cluster size                 4.00	Time to split nodes on this level 0.01 sec



In [None]:
# One array of row indices of `x` per cluster
cluster

[array([ 3,  1, 15, 14]),
 array([ 2,  5, 11, 10]),
 array([ 9,  8, 12,  4]),
 array([ 7, 13,  6,  0])]

In [None]:
# Size the sampler from the data itself; the cluster assignment is attached in a later cell.
sampler = ClusterGroupedSampler(n=len(x))
dl = DataLoader(torch.arange(len(x)), sampler=sampler, batch_size=5)

In [None]:
# Attach the clusters: from now on the sampler yields each cluster's indices contiguously
dl.sampler.set_cluster(cluster)

In [None]:
# Every batch (size 5) drawn through the cluster-aware sampler
list(dl)

[tensor([13,  6,  0,  7,  5]),
 tensor([ 2, 10, 11,  3,  1]),
 tensor([14, 15, 12,  8,  4]),
 tensor([9])]

## `get_cluster_mapping`: CLUSTER MAPPING

In [None]:
#| export
def get_cluster_mapping(embeddings:torch.Tensor, cluster_sz:int=3):
    """Assign every row of `embeddings` a cluster id via `BalancedClusters.proc`.

    Clustering runs on a half-precision view of `embeddings`.

    Args:
        embeddings: tensor with one row per item.
        cluster_sz: target leaf size passed as `min_cluster_sz`.

    Returns:
        `(mapping, n_clusters)`: `mapping` is an int64 CPU tensor of length `embeddings.shape[0]`
        holding each row's cluster id (rows absent from every cluster keep 0), and
        `n_clusters` is the number of clusters.
    """
    n_rows = embeddings.shape[0]
    clusters = BalancedClusters.proc(embeddings.half(), min_cluster_sz=cluster_sz)
    mapping = torch.zeros(n_rows, dtype=torch.int64)
    for cid, members in enumerate(clusters):
        mapping[members] = cid
    n_clusters = len(clusters)
    return mapping, n_clusters
    

In [None]:
#| export
def get_cluster_size(emb_sz, cluster_sz):
    """Number of leaf clusters for `emb_sz` items with leaves of about `cluster_sz` items.

    Returns the smallest power of two >= emb_sz / cluster_sz, i.e. 2**ceil(log2(emb_sz / cluster_sz)),
    and never less than 1. Computed with integer arithmetic, so it is exact (no float-log
    rounding) and always an `int`.
    """
    # ceil(emb_sz / cluster_sz); clamp so tiny (or empty) banks still map to one cluster
    n_leaves = max(int(-(-emb_sz // cluster_sz)), 1)
    return 1 << (n_leaves - 1).bit_length()
    