# Project 01 - Color Compression

## Thông tin sinh viên

- Họ và tên: Tô Quốc Thanh
- MSSV: 22127388
- Lớp: 22CLC10

## Import các thư viện liên quan

In [843]:
import numpy as np
import matplotlib.pyplot as plt 
import matplotlib.image as img 

## Helper functions

In [1105]:
def read_img(img_path):
    """Read an image file into a NumPy array whose channel values lie in 0..255.

    Per the matplotlib documentation, ``matplotlib.image.imread`` returns PNG
    files as floats in [0, 1] but other formats (e.g. JPEG) as integer arrays.
    The rest of this notebook assumes 0..255: ``kmeans`` works on integer-like
    colors and ``generate_2d_img`` casts to ``uint8``. Without rescaling, every
    PNG channel would collapse to 0 or 1 and the result would be almost black,
    so float images are converted to ``uint8`` here.

    Args:
        img_path: path to the image file.

    Returns:
        The image array. Float input is rescaled to ``uint8`` 0..255; integer
        input is returned unchanged.
    """
    image = img.imread(img_path)
    if np.issubdtype(image.dtype, np.floating):
        # Map [0, 1] -> [0, 255] with rounding; clip guards against tiny overshoots.
        image = np.clip(np.rint(image * 255), 0, 255).astype(np.uint8)
    return image

def show_img(img_2d):
    """Display ``img_2d`` with matplotlib.

    ``plt.show()`` is called explicitly for two reasons:
    - Without it, a plain (non-notebook) script never displays the figure.
    - With a notebook inline backend, the figure typically only appears once
      the whole cell has finished, i.e. after ``main`` has already asked
      where to save the result.

    With a GUI backend this call may block until the window is closed.
    """
    plt.imshow(img_2d)
    plt.show()

def save_img(img_2d, img_path):
    """Write the image array ``img_2d`` to the file at ``img_path``.

    Delegates to ``matplotlib.pyplot.imsave``; any error it raises propagates
    to the caller.
    """
    plt.imsave(fname=img_path, arr=img_2d)

def convert_img_to_1d(img_2d):
    """Flatten the first two axes of an image into a Python list of pixels.

    For an array of shape (H, W, C) the result is a list of H * W entries,
    each being one row-major pixel of shape (C,).
    """
    pixels = []
    for row in img_2d:
        pixels.extend(row)
    return pixels

def _assign_labels(pixels, centroids):
    """Return, for each row of ``pixels``, the index of its nearest centroid."""
    # Fill an (n_pixels, k) matrix of squared Euclidean distances one centroid
    # at a time; this avoids the (n_pixels, k, channels) temporary that full
    # broadcasting would allocate. Squared distance has the same argmin.
    sq_dists = np.empty((pixels.shape[0], centroids.shape[0]))
    for j, centroid in enumerate(centroids):
        diff = pixels - centroid
        sq_dists[:, j] = (diff * diff).sum(axis=1)
    return np.argmin(sq_dists, axis=1)


def kmeans(img_1d, k_clusters, max_iter, init_centroids='in_pixels'):
    """Cluster pixel colors with Lloyd's k-means algorithm.

    Args:
        img_1d: sequence of pixels, each a sequence of channel values (e.g. the
            output of ``convert_img_to_1d``). Must be convertible to a 2-D array
            of shape (n_pixels, n_channels).
        k_clusters: number of clusters (output colors).
        max_iter: maximum number of assign/update iterations.
        init_centroids: either of
            ``'random'``: draw each centroid channel uniformly from 0..255.
            ``'in_pixels'``: pick colors that occur in the image, distinct
            whenever the image has at least ``k_clusters`` distinct colors.

    Returns:
        ``(centroids, labels)``:
            ``centroids``: int32 array of shape (k_clusters, n_channels),
            holding the rounded cluster means.
            ``labels``: integer array of length n_pixels giving each
            pixel's cluster index.

    Raises:
        ValueError: if ``init_centroids`` is not ``'random'`` or ``'in_pixels'``.
    """
    # Work in float so cluster means are not truncated between iterations.
    pixels = np.asarray(img_1d, dtype=np.float64)
    n_channels = pixels.shape[1]

    if init_centroids == 'random':
        centroids = np.random.randint(
            0, 256, size=(k_clusters, n_channels)).astype(np.float64)
    elif init_centroids == 'in_pixels':
        unique_colors = np.unique(pixels, axis=0)
        # With fewer distinct colors than clusters, sampling without
        # replacement would raise; allow repeats instead (the duplicates simply
        # end up as empty clusters).
        replace = k_clusters > unique_colors.shape[0]
        chosen = np.random.choice(unique_colors.shape[0], k_clusters, replace=replace)
        centroids = unique_colors[chosen]
    else:
        raise ValueError(
            f"init_centroids must be 'random' or 'in_pixels', got {init_centroids!r}")

    global_mean = pixels.mean(axis=0)
    labels = None  # no assignment yet, so the first pass can never "converge"
    for _ in range(max_iter):
        new_labels = _assign_labels(pixels, centroids)
        if labels is not None and np.array_equal(new_labels, labels):
            break  # assignments are stable: converged
        labels = new_labels

        # Update step: move each centroid to the mean of its members. An empty
        # cluster falls back to the mean of all pixels.
        for k in range(k_clusters):
            members = pixels[labels == k]
            centroids[k] = members.mean(axis=0) if members.size else global_mean

    if labels is None:
        # max_iter <= 0: still return a valid assignment for the initial centroids.
        labels = _assign_labels(pixels, centroids)

    return np.rint(centroids).astype(np.int32), labels

def generate_2d_img(img_2d_shape, centroids, labels):
    """Rebuild an image in which every pixel is replaced by its cluster centroid.

    Args:
        img_2d_shape: shape of the original image, e.g. (H, W, C).
        centroids: array of cluster colors, one row per cluster.
        labels: array giving each flattened pixel's cluster index.

    Returns:
        A ``uint8`` array of shape ``img_2d_shape`` (values are expected to lie
        in 0..255).
    """
    recolored_pixels = centroids[labels]
    image = recolored_pixels.reshape(img_2d_shape)
    return image.astype(np.uint8)

# Your additional functions here
def is_valid_image_extension(file_name):
    """Return True if ``file_name`` ends with a supported output extension.

    The check is case-insensitive. The accepted extensions are .png, .jpg,
    .jpeg, .bmp, .tiff and .pdf, which is the same list that ``main`` shows
    the user when an extension is rejected.
    """
    valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.pdf')
    # str.endswith accepts a tuple, so one call checks every option.
    return file_name.lower().endswith(valid_extensions)

## Your tests

In [None]:
# img_path = 'img.jpg'  
# img_2d = read_img(img_path)
# img_1d = convert_img_to_1d(img_2d)
# k_clusters = 7
# max_iter = 100
# tmp = kmeans(img_1d, k_clusters, max_iter,'in_pixels')
# new_img_2d = generate_2d_img(img_2d.shape, tmp[0], tmp[1])
# show_img(new_img_2d)

## Main function

In [1107]:
# YOUR CODE HERE
def _prompt_positive_int(prompt, error_message):
    """Prompt repeatedly until the user enters an integer greater than zero."""
    while True:
        try:
            value = int(input(prompt))
        except ValueError:
            print(error_message)
            continue
        if value > 0:
            return value
        print(error_message)


def main():
    """Interactively compress an image's colors with k-means and save the result.

    The user is prompted, in order, for:
    - the input image path;
    - the number of clusters;
    - the maximum number of iterations;
    - the centroid initialization method.

    The recolored image is then displayed with ``show_img``. Finally the user
    is prompted for an output path until ``save_img`` succeeds.
    """
    while True:
        # strip() tolerates stray whitespace around a pasted path.
        img_path = input("Enter the path to the image (including filename and extension): ").strip()
        try:
            img_2d = read_img(img_path)
            break
        except FileNotFoundError:
            print("File does not exist. Please enter a valid file path.")
        except (OSError, ValueError) as e:
            # e.g. a directory, an unreadable file, or an unsupported/corrupt image
            print(f"Could not read the image: {e}")

    img_1d = convert_img_to_1d(img_2d)

    k_clusters = _prompt_positive_int(
        "Enter the number of clusters: ",
        "Invalid input. Please enter a positive integer for the number of clusters.")
    max_iter = _prompt_positive_int(
        "Enter the maximum number of iterations: ",
        "Invalid input. Please enter a positive integer for the maximum number of iterations.")

    while True:
        init_centroids = input("Enter the method to initialize centroids (random/in_pixels): ").strip().lower()
        if init_centroids in ('random', 'in_pixels'):
            break
        print("Invalid input. Please enter 'random' or 'in_pixels'.")

    centroids, labels = kmeans(img_1d, k_clusters, max_iter, init_centroids)
    new_img_2d = generate_2d_img(img_2d.shape, centroids, labels)
    show_img(new_img_2d)

    while True:
        save_path = input("Enter the path to save the segmented image (including filename and extension): ").strip()
        if not is_valid_image_extension(save_path):
            print("Invalid file extension. Please enter a save path with one of the following extensions: .png, .jpg, .jpeg, .bmp, .tiff, .pdf")
            continue
        try:
            save_img(new_img_2d, save_path)
            break
        except (OSError, ValueError) as e:
            # OSError covers missing directories/permissions; ValueError covers
            # formats the writer rejects. Report and ask again.
            print(e)

    print("Successfully saved the image!")

In [None]:
# Entry point: run the interactive color-compression workflow (main) when this
# code is executed as the top-level module.
if __name__ == "__main__":
    main()