# In [1]:
import os
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np


def kmeans_plus_plus_init(X, k):
    """Initialize centroids using K-Means++.

    The first centroid is drawn uniformly at random from the rows of ``X``.
    Each further centroid is drawn with probability proportional to the
    squared Euclidean distance from a point to its nearest already-chosen
    centroid.

    Args:
        X: 2-D array of shape (N, D); one data point per row.
        k: Number of centroids to select.

    Returns:
        A float array of shape (k, D) whose rows are copies of rows of ``X``.

    Notes:
        Uses the global ``np.random`` state. If every point coincides with an
        already-chosen centroid (all distances are zero), the next centroid is
        drawn uniformly instead of dividing by a zero total.
    """
    N, D = X.shape
    centroids = np.zeros((k, D))

    # 1. Select the first centroid uniformly at random.
    idx = np.random.randint(0, N)
    centroids[0] = X[idx]

    # Squared distance from each point to its nearest chosen centroid. It is
    # updated incrementally so each round costs O(N * D) instead of rebuilding
    # an (N, i, D) broadcast.
    closest_sq = np.full(N, np.inf)

    # 2. Select the remaining centroids.
    for i in range(1, k):
        newest_sq = np.sum((X - centroids[i - 1]) ** 2, axis=1)
        np.minimum(closest_sq, newest_sq, out=closest_sq)

        total = closest_sq.sum()
        if np.isfinite(total) and total > 0:
            # Sample proportionally to the squared distance.
            idx = np.random.choice(N, p=closest_sq / total)
        else:
            # Degenerate data (e.g. all points identical): every weight is
            # zero, so normalising would yield NaN probabilities.
            idx = np.random.randint(0, N)
        centroids[i] = X[idx]

    return centroids


def kmeans_quantization(image, k=16, max_iterations=20, tol=1e-5, sample_fraction=0.1):
    """Perform K-means color quantization with sampling + K-Means++ initialization.

    Centroids are fitted on a random subsample of the pixels and then applied
    to every pixel of the image.

    Args:
        image: Array of shape (H, W, C).
        k: Number of colours (clusters).
        max_iterations: Upper bound on the number of K-means iterations.
        tol: Iteration stops once the change in the sample L2 norm between
            two consecutive iterations drops below this value.
        sample_fraction: Fraction of pixels used to fit the centroids. The
            resulting sample size is clamped so that it is at least
            ``min(k, N)`` and at most ``N``.

    Returns:
        A tuple ``(quantized_image, l2_norms)``: the (H, W, C) image with each
        pixel replaced by its nearest centroid, and the list of per-iteration
        L2 norms measured on the sample.

    Raises:
        ValueError: If the image contains no pixels.

    Notes:
        Reseeds the global NumPy RNG with 42 on every call so that results
        are reproducible.
    """
    H, W, C = image.shape
    # Use the real channel count rather than a hard-coded 3.
    X = image.reshape(-1, C)  # NxC
    N = X.shape[0]
    if N == 0:
        raise ValueError("Cannot quantize an empty image")

    # --- Subsampling for speed ---
    np.random.seed(42)
    # Clamp so that tiny images or extreme fractions still give a usable sample.
    sample_size = min(N, max(int(N * sample_fraction), k, 1))
    sample_idx = np.random.choice(N, sample_size, replace=False)
    X_sample = X[sample_idx]

    # --- K-Means++ initialization on the sample ---
    centroids = kmeans_plus_plus_init(X_sample, k)

    l2_norms = []

    for iteration in range(max_iterations):
        # 1. Assign each sample point to its nearest centroid.
        dists = np.linalg.norm(X_sample[:, np.newaxis, :] - centroids[np.newaxis, :, :], axis=2)
        labels = np.argmin(dists, axis=1)

        # 2. Update centroids.
        new_centroids = np.zeros_like(centroids)
        for i in range(k):
            members = X_sample[labels == i]
            if members.shape[0] > 0:
                new_centroids[i] = members.mean(axis=0)
            else:
                # Empty cluster: re-seed it from a random sample point.
                new_centroids[i] = X_sample[np.random.randint(0, sample_size)]
        centroids = new_centroids

        # 3. L2 norm on the sample (updated centroids, current labels).
        l2_norm = np.sqrt(np.sum((X_sample - centroids[labels]) ** 2))
        l2_norms.append(l2_norm)

        # 4. Convergence check.
        if iteration > 0 and abs(l2_norms[-2] - l2_norms[-1]) < tol:
            break

    # --- Apply centroids to the full image ---
    # Process pixels in chunks so the (pixels, k, C) distance tensor stays
    # small even for large images.
    chunk_size = 65536
    labels_full = np.empty(N, dtype=np.intp)
    for start in range(0, N, chunk_size):
        stop = start + chunk_size
        chunk_dists = np.linalg.norm(
            X[start:stop, np.newaxis, :] - centroids[np.newaxis, :, :], axis=2
        )
        labels_full[start:stop] = np.argmin(chunk_dists, axis=1)
    quantized_image = centroids[labels_full].reshape(H, W, C)

    return quantized_image, l2_norms


def calculate_l2_norm(original, quantized):
    """Return the L2 norm of the element-wise difference of two arrays.

    The result is the square root of the sum of squared differences taken
    over all elements.
    """
    residual = original - quantized
    squared_error = np.sum(np.square(residual))
    return np.sqrt(squared_error)


def _to_uint8(unit_image):
    """Convert a float image scaled to [0, 1] into uint8, rounding and clipping."""
    # astype(np.uint8) alone truncates, so values such as 254.99998 would
    # become 254; round to nearest first and clip to the valid range.
    return np.clip(np.rint(unit_image * 255.0), 0, 255).astype(np.uint8)


def process_image(input_path, k=16, max_iterations=20, visualize=True, log_file="L2_norm_log.txt"):
    """Quantize one image, save the result, and log the per-iteration L2 norms.

    The image is read with ``cv2.imread``, scaled to [0, 1], and quantized
    with ``kmeans_quantization``. The result is written next to the input as
    ``<input without extension>_quantized.png``. The L2 norms are appended to
    ``log_file``.

    Args:
        input_path: Path of the image to read.
        k: Number of colours.
        max_iterations: Maximum number of K-means iterations.
        visualize: If true, show the original and quantized images with
            matplotlib.
        log_file: Path of the text file the L2 norms are appended to.

    Returns:
        A tuple ``(total_l2, exec_time)``: the L2 norm between the original
        and quantized images, and the quantization time in seconds.

    Raises:
        FileNotFoundError: If ``cv2.imread`` could not load ``input_path``.
    """
    image = cv2.imread(input_path)
    if image is None:
        raise FileNotFoundError(f"Image not found: {input_path}")
    image = image.astype(np.float32) / 255.0

    # perf_counter is a monotonic clock intended for measuring durations.
    start_time = time.perf_counter()
    quantized, l2_norms = kmeans_quantization(image, k, max_iterations)
    exec_time = time.perf_counter() - start_time

    # Save quantized image. splitext strips only the final extension, so
    # paths such as "./images/lena.png" keep their directory part.
    output_path = os.path.splitext(input_path)[0] + "_quantized.png"
    cv2.imwrite(output_path, _to_uint8(quantized))

    # Save L2 norms log
    with open(log_file, 'a') as f:
        f.write(f"{input_path} L2 Norms:\n")
        f.write('\n'.join(map(str, l2_norms)) + '\n\n')

    # Calculate total L2 norm
    total_l2 = calculate_l2_norm(image, quantized)

    if visualize:
        # The images are converted from BGR to RGB before display.
        plt.imshow(cv2.cvtColor(_to_uint8(image), cv2.COLOR_BGR2RGB))
        plt.title("Original Image")
        plt.axis("off")
        plt.show()

        plt.imshow(cv2.cvtColor(_to_uint8(quantized), cv2.COLOR_BGR2RGB))
        plt.title("Quantized Image")
        plt.axis("off")
        plt.show()

    print(f"[{input_path}] Total L2 Norm: {total_l2:.2f}, Execution time: {exec_time:.2f} sec")
    return total_l2, exec_time


if __name__ == "__main__":
    # Run configuration for a single image.
    input_path = "lena.png"
    k = 16
    max_iterations = 20
    visualize = True
    log_file = "L2_norm_log.txt"

    # Truncate the log so this run starts from an empty file.
    with open(log_file, "w"):
        pass

    process_image(
        input_path,
        k=k,
        max_iterations=max_iterations,
        visualize=visualize,
        log_file=log_file,
    )


# Output: FileNotFoundError: Image not found: lena.png