In [2]:
import numpy as np
from utils import *
import faiss
import random
import time
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from collections import Counter

In [22]:
anno_files_list = get_all_anno_files(".")
# Transposed so that each row is one bit code to cluster and each column one bit.
anno_bit_vectors = load_all_annotation_bit_vectors(anno_files_list).T
print(anno_bit_vectors.shape)  # (rows, cols)

# Pack 8 bit-columns per byte (np.packbits zero-pads when cols is not a multiple
# of 8), giving (rows, ceil(cols / 8)) uint8 codes, and force a C-contiguous layout.
packed_anno_bit_vectors = np.ascontiguousarray(np.packbits(anno_bit_vectors, axis=1))
print(packed_anno_bit_vectors.shape)  # (rows, cols / 8)

(60801408, 32)
(60801408, 4)


In [33]:
def compute_centroid(cluster):
    """
    Compute the centroid (bitwise majority) of a cluster of dim-bit binary codes (dim/8 uint8).

    A bit is set in the centroid when at least half of the codes have it set,
    so exact ties resolve to 1.

    Args:
        cluster (list of np.ndarray | np.ndarray): (dim/8,) binary codes as `uint8`
            arrays, either as a list or already stacked into an (m, dim/8) array.

    Returns:
        np.ndarray | None: Centroid as a (dim/8,) `uint8` array, or None when the
        cluster is empty.
    """
    # len() rather than `not cluster`: truth-testing a multi-row (or empty) ndarray raises.
    if len(cluster) == 0:
        return None

    cluster_bits = np.asarray(cluster, dtype=np.uint8)
    # Add a trailing axis and unpack every byte into its 8 bits (MSB first): (m, dim/8, 8).
    bits = np.unpackbits(cluster_bits[..., np.newaxis], axis=-1)
    ones = bits.sum(axis=0)
    # Integer form of `ones >= m / 2` (majority vote, ties -> 1).
    majority = 2 * ones >= len(cluster_bits)
    # Re-pack each byte's 8 bits in the same MSB-first order and drop the length-1 axis.
    return np.packbits(majority, axis=-1)[..., 0]


def compute_hamming_distances(args):
    """
    Find the centroid nearest (in Hamming distance) to a single point.

    Takes one packed argument tuple so it can be mapped with
    `multiprocessing.Pool.imap`.

    Args:
        args (tuple): (point, centroids) where point is a (dim/8,) uint8 array and
            centroids is a (k, dim/8) uint8 array (or a sequence of such rows).

    Returns:
        tuple: (point, index of closest centroid). Ties go to the lowest index,
        as with np.argmin.
    """
    point, centroids = args
    # asarray avoids copying the whole centroid matrix on every call when it is already an ndarray.
    distances = hamming_distances_1_to_n(point, np.asarray(centroids))
    return point, np.argmin(distances)

def hamming_distances_1_to_n(point, array):
    """
    Hamming distances between one packed binary code and each row of `array`.

    Args:
        point: numpy array of uint8, shape (dim/8,) — or any shape that broadcasts
            against `array` (e.g. (c, 1, dim/8) for several query points).
        array: numpy array of uint8, shape (m, dim/8).

    Returns:
        numpy array of unsigned ints, shape (m,) (or the broadcast shape without the
        last axis): number of differing bits per row. NumPy accumulates the sum of
        uint8 values in its default-precision unsigned integer, so distances above
        255 do not wrap.
    """
    # Bytes live on the last axis; unpack them to bits and count the ones.
    return np.unpackbits(np.bitwise_xor(point, array), axis=-1).sum(axis=-1)


def _nearest_centroids(points, centroids, progress=False, max_chunk_bytes=1 << 24):
    """
    Return the index of the Hamming-nearest centroid for every row of `points`.

    Args:
        points (np.ndarray): (m, dim/8) packed binary codes, `uint8`.
        centroids (np.ndarray): (k, dim/8) packed binary codes, `uint8`.
        progress (bool): Wrap the chunk loop in a tqdm progress bar.
        max_chunk_bytes (int): Approximate size cap of the (chunk, k, dim/8) XOR
            intermediate; rows are processed in chunks to respect it.

    Returns:
        np.ndarray: (m,) int64 centroid indices. Ties go to the lowest index,
        exactly like np.argmin.
    """
    # popcount_u8[b] == number of set bits in byte value b.
    popcount_u8 = np.unpackbits(
        np.arange(256, dtype=np.uint8)[:, np.newaxis], axis=1
    ).sum(axis=1).astype(np.uint8)
    chunk = max(1, max_chunk_bytes // max(1, centroids.size))
    labels = np.empty(len(points), dtype=np.int64)
    starts = range(0, len(points), chunk)
    if progress:
        starts = tqdm(starts)
    for start in starts:
        rows = points[start:start + chunk]
        # (c, 1, d) ^ (1, k, d) -> (c, k, d) differing bits; popcount per byte, sum per pair.
        xor = np.bitwise_xor(rows[:, np.newaxis, :], centroids[np.newaxis, :, :])
        distances = popcount_u8[xor].sum(axis=2, dtype=np.uint32)
        labels[start:start + chunk] = distances.argmin(axis=1)
    return labels


def mini_batch_kmeans_hamming(dataset, k, batch_size=1000, max_iters=10, seed=None, verbose=True):
    """
    Perform Mini-Batch K-means clustering of packed binary codes under Hamming distance.

    Each cluster keeps a running count of how many of its assigned points have
    each bit set; its centroid is the bitwise majority of every point assigned to
    it so far (ties resolve to 1). This mirrors mini-batch k-means, where a centre
    is the running mean of all points ever assigned to it.

    Args:
        dataset (np.ndarray): (n, dim/8) packed binary codes as `uint8`
            (e.g. the output of `np.packbits(bits, axis=1)`).
        k (int): Number of clusters, 1 <= k <= n.
        batch_size (int): Size of each mini-batch (capped at n).
        max_iters (int): Maximum number of mini-batch iterations.
        seed (int | None): Seed for centroid initialisation and batch sampling.
        verbose (bool): Print status messages and show tqdm progress bars.

    Returns:
        list: Cluster index for each data point, in dataset order.
        np.ndarray: (k, dim/8) `uint8` final centroids.

    Raises:
        ValueError: If k is not in [1, n].
    """
    dataset = np.asarray(dataset)
    n, n_bytes = dataset.shape
    if not 1 <= k <= n:
        raise ValueError(f"k must be between 1 and {n}, got {k}")
    batch_size = min(batch_size, n)
    rng = np.random.default_rng(seed)

    # Step 1: initialise centroids with k distinct rows of the dataset (fancy indexing copies).
    centroids = dataset[rng.choice(n, k, replace=False)]
    # Running per-cluster statistics over every point assigned so far.
    ones_per_bit = np.zeros((k, n_bytes * 8), dtype=np.int64)
    cluster_counts = np.zeros(k, dtype=np.int64)
    if verbose:
        print("Initialization complete.")

    # Step 2: iterate until the centroids stop changing or max_iters is reached.
    iterations = tqdm(range(max_iters)) if verbose else range(max_iters)
    for _ in iterations:
        mini_batch = dataset[rng.choice(n, batch_size, replace=False)]
        batch_labels = _nearest_centroids(mini_batch, centroids)

        # Accumulate bit counts per cluster. Averaging the packed uint8 bytes as
        # integers would not average the individual bits (and would overflow).
        np.add.at(ones_per_bit, batch_labels, np.unpackbits(mini_batch, axis=1))
        cluster_counts += np.bincount(batch_labels, minlength=k)

        # Every cluster that has ever received a point gets the bitwise majority;
        # clusters never assigned keep their initial centroid.
        seen = cluster_counts > 0
        majority = 2 * ones_per_bit[seen] >= cluster_counts[seen][:, np.newaxis]
        new_centroids = centroids.copy()
        new_centroids[seen] = np.packbits(majority, axis=1)

        if np.array_equal(centroids, new_centroids):
            if verbose:
                print("Convergence reached.")
            break
        centroids = new_centroids

    # Step 3: assign every point to its nearest final centroid.
    # NOTE(review): exact NumPy search over tens of millions of rows x tens of
    # thousands of centroids is slow; faiss.IndexBinaryFlat (imported at the top
    # of this notebook) performs the same exact Hamming search much faster.
    cluster_assignments = _nearest_centroids(dataset, centroids, progress=verbose)
    return cluster_assignments.tolist(), centroids


# Clustering hyper-parameters
N_CLUSTERS = 50_000
BATCH_SIZE = 10_000
MAX_ITERS = 100

# Perform Mini-Batch K-means clustering
cluster_assignments, centroids = mini_batch_kmeans_hamming(
    packed_anno_bit_vectors, k=N_CLUSTERS, batch_size=BATCH_SIZE, max_iters=MAX_ITERS
)
# Save both the per-point cluster assignments and the packed centroids so the
# clustering can be reused without re-running it.
np.save("cluster_assignments.npy", np.asarray(cluster_assignments))
np.save("centroids.npy", np.asarray(centroids))

Initialization complete.


 99%|█████████▉| 99/100 [07:45<00:04,  4.72s/it]

: 

In [12]:
def hamming_distances(x, y):
    """
    Hamming distances between one packed binary code and each row of `y`.

    Args:
        x: numpy array of uint8, shape (dim/8,) — or any shape that broadcasts
            against `y` (e.g. (c, 1, dim/8) for several query codes).
        y: numpy array of uint8, shape (m, dim/8).

    Returns:
        numpy array of unsigned ints, shape (m,) (or the broadcast shape without the
        last axis): number of differing bits per row. NumPy accumulates the sum of
        uint8 values in its default-precision unsigned integer, so distances above
        255 do not wrap.
    """
    # Bytes live on the last axis; unpack them to bits and count the ones.
    return np.unpackbits(np.bitwise_xor(x, y), axis=-1).sum(axis=-1)

# Example usage of hamming_distances
x = np.array([0b00000000, 0b11111111], dtype=np.uint8)
y = np.array([[0b11111111, 0b00000000], [0b11001100, 0b00110011], [0b11001111, 0b00110011]], dtype=np.uint8)
print(x.shape)
print(y.shape)
example_distances = hamming_distances(x, y)
print(example_distances)
# Row 0 differs from x in all 16 bits; rows 1 and 2 differ in 8 and 10 bits.
assert example_distances.tolist() == [16, 8, 10]

(2,)
(3, 2)
[16  8 10]


In [20]:
import numpy as np

def compute_centroid(cluster):
    """
    Compute the centroid (bitwise majority) of a cluster of dim-bit binary codes (dim/8 uint8).

    A bit is set in the centroid when at least half of the codes have it set,
    so exact ties resolve to 1.

    Args:
        cluster (list of np.ndarray | np.ndarray): (dim/8,) binary codes as `uint8`
            arrays, either as a list or already stacked into an (m, dim/8) array.

    Returns:
        np.ndarray | None: Centroid as a (dim/8,) `uint8` array, or None when the
        cluster is empty.
    """
    # len() rather than `not cluster`: truth-testing a multi-row (or empty) ndarray raises.
    if len(cluster) == 0:
        return None

    cluster_bits = np.asarray(cluster, dtype=np.uint8)
    # Add a trailing axis and unpack every byte into its 8 bits (MSB first): (m, dim/8, 8).
    bits = np.unpackbits(cluster_bits[..., np.newaxis], axis=-1)
    ones = bits.sum(axis=0)
    # Integer form of `ones >= m / 2` (majority vote, ties -> 1).
    majority = 2 * ones >= len(cluster_bits)
    # Re-pack each byte's 8 bits in the same MSB-first order and drop the length-1 axis.
    return np.packbits(majority, axis=-1)[..., 0]


# Example usage of compute_centroid.
# Bits set in exactly 2 of the 4 codes are ties and resolve to 1, so the lowest
# bit of byte 0 is set: expected [0b10101011, 0b01010101] == [171, 85].
cluster = [
    np.array([0b10101010, 0b01010101], dtype=np.uint8),
    np.array([0b10101010, 0b01010101], dtype=np.uint8),
    np.array([0b00000001, 0b11111111], dtype=np.uint8),
    np.array([0b01111111, 0b00000000], dtype=np.uint8),
]

example_centroid = compute_centroid(cluster)
assert example_centroid.tolist() == [0b10101011, 0b01010101]
print(example_centroid)
# print as binary
print(f"{example_centroid[0]:08b}")
print(f"{example_centroid[1]:08b}")

[171  85]
10101011
01010101


In [12]:
# Scratch check: the bit pattern 01010101 (second byte of the compute_centroid example) is 85.
int("01010101", 2)

85