In [None]:
import os
import glob
import json

import torch

from sklearn.metrics import silhouette_score
from scipy.spatial.distance import cdist


# Device string used for tensor placement: the GPU when CUDA is usable,
# otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
# ===== K-MEANS BASE =====
@torch.jit.script
def calculate_distances(data: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    """Return the pairwise Euclidean distances between points and centers.

    Args:
        data: Points, one per row.
        centers: Candidate centers, one per row, with the same feature
            size as ``data``.

    Returns:
        Tensor whose entry ``[i, j]`` is the distance from ``data[i]`` to
        ``centers[j]``.
    """
    # p=2.0 is torch.cdist's default; spelled out to make the metric explicit.
    pairwise = torch.cdist(data, centers, p=2.0)
    return pairwise

@torch.jit.script
def assign_clusters(distances: torch.Tensor) -> torch.Tensor:
    """Return, for every row of ``distances``, the column index of its minimum.

    Args:
        distances: Matrix with one row per point and one column per center.

    Returns:
        Integer tensor holding the index of the nearest center for each row.
    """
    nearest = distances.argmin(dim=1)
    return nearest

@torch.jit.script
def calculate_sse(data: torch.Tensor, centers: torch.Tensor, cluster_labels: torch.Tensor) -> float:
    """Return the sum of squared errors of ``data`` around its assigned centers.

    Args:
        data: Points, one per row.
        centers: Cluster centers, one per row, broadcast-compatible with
            the rows of ``data``.
        cluster_labels: Integer tensor with one entry per row of ``data``
            giving the index into ``centers`` assigned to that row.  Every
            entry must be a valid index into ``centers``.

    Returns:
        The total of ``(data[i] - centers[cluster_labels[i]]) ** 2`` over all
        points and features, as a Python float (``0.0`` for empty ``data``).
    """
    # Gather each point's own center in one indexing op instead of masking
    # the data once per cluster: O(n) work, no per-cluster temporaries, and
    # no failure when there are zero clusters.
    residuals = data - centers[cluster_labels]
    sse = torch.sum(residuals ** 2)
    return sse.item()

def initialize_kmeans_centers(data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Greedily choose up to ``floor(sqrt(n))`` rows of ``data`` as initial centers.

    Centers are added one at a time.  In each round every row that is not
    equal to an already chosen center is tried as the next center, and the
    row whose addition yields the lowest SSE (as reported by
    ``calculate_sse``) is kept.

    Selection stops early when a round finds no usable candidate, e.g. when
    ``data`` has fewer distinct rows than ``floor(sqrt(n))``; in that case
    fewer centers are returned.

    Args:
        data: Points, one per row.  Presumably a 2-D floating-point tensor,
            since it is passed to ``calculate_distances`` -- confirm against
            callers.

    Returns:
        A ``(labels, centers)`` pair: the cluster index assigned to every
        row of ``data`` under the final set of centers, and the chosen
        centers stacked along dimension 0.

    Raises:
        ValueError: If ``data`` has no rows, or if no center could be
            selected at all (every candidate SSE failed the ``<``
            comparison, e.g. because it was NaN).
    """
    # Function-scope import: only this function needs the integer sqrt.
    import math

    n = len(data)
    if n == 0:
        raise ValueError("data must contain at least one point")

    # Exact integer square root; no tensor round-trip for a scalar.
    target_count = math.isqrt(n)
    centers: list[torch.Tensor] = []
    final_labels = None

    while len(centers) < target_count:
        sse_min = float('inf')
        new_center = None
        round_labels = None

        for i in range(n):
            candidate = data[i]
            # A row identical to an existing center cannot lower the SSE.
            if any(torch.equal(candidate, center) for center in centers):
                continue

            temp_centers = torch.stack(centers + [candidate])
            distances = calculate_distances(data, temp_centers)
            cluster_labels = assign_clusters(distances)
            sse = calculate_sse(data, temp_centers, cluster_labels)

            if sse < sse_min:
                sse_min = sse
                new_center = candidate
                # cluster_labels is a fresh tensor on every pass, so it can
                # be kept as-is without cloning.
                round_labels = cluster_labels

        if new_center is None:
            # No candidate left (all remaining rows duplicate a center, or
            # every SSE was NaN): stop instead of appending None.
            break

        centers.append(new_center)
        final_labels = round_labels

    if not centers:
        raise ValueError(
            "no center could be selected; data may contain non-finite values"
        )

    return final_labels, torch.stack(centers)