In [6]:
# %pip install numpy matplotlib scikit-learn imageio
# %pip install rasterio
# %pip install -U gdal

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import imageio

def cluster_image(image_path, n_clusters=3, output_path='mask.tif', random_state=None):
    """Cluster an image's pixels by band values with k-means, save and plot the label mask.

    Parameters
    ----------
    image_path : str
        Path to an image readable by ``imageio.imread``.
    n_clusters : int, default 3
        Number of k-means clusters.
    output_path : str, default 'mask.tif'
        File the per-pixel label mask is written to.
    random_state : int or None, default None
        Seed forwarded to ``KMeans``; pass an int for reproducible labels.

    Returns
    -------
    None
        Writes ``output_path`` and displays the mask with matplotlib.
    """
    img = imageio.imread(image_path)
    height, width = img.shape[:2]

    # One row per pixel, one column per band. Use the image's real band count
    # instead of assuming 3: reshaping a 4-band image to (-1, 3) silently mixes
    # bands across pixels and then fails to reshape back to (height, width).
    # A 2-D (single-band) image becomes a single column.
    n_bands = img.shape[2] if img.ndim == 3 else 1
    pixel_values = img.reshape(-1, n_bands).astype(np.float32)

    # n_init is explicit to keep the historical default of 10 and to silence
    # scikit-learn's FutureWarning about that default changing.
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state)
    # fit_predict returns the labels from the fit; no second assignment pass.
    labels = kmeans.fit_predict(pixel_values)
    segmented_img = labels.reshape(height, width)

    # uint8 only holds labels 0..255; widen for larger cluster counts.
    mask_dtype = np.uint8 if n_clusters <= 256 else np.uint16
    imageio.imwrite(output_path, segmented_img.astype(mask_dtype))

    plt.imshow(segmented_img)
    plt.title(f'Image clustered into {n_clusters} colors')
    plt.axis('off')
    plt.show()

# Cluster the sample image into 3 groups; replace 'naip.tif' with your own
# image path. The label mask is also written to 'mask.tif'.
cluster_image('naip.tif', n_clusters=3)


In [10]:
import numpy as np
import rasterio
from sklearn.cluster import KMeans
import imageio

def cluster_image(image_path, n_clusters=3, output_path='clustered_image.tif', random_state=None):
    """Cluster a raster's pixels with k-means and save the labels as a georeferenced GeoTIFF.

    Parameters
    ----------
    image_path : str
        Path to a raster readable by ``rasterio.open``.
    n_clusters : int, default 3
        Number of k-means clusters.
    output_path : str, default 'clustered_image.tif'
        Destination single-band GeoTIFF; it receives the source CRS and transform.
    random_state : int or None, default None
        Seed forwarded to ``KMeans``; pass an int for reproducible labels.

    Returns
    -------
    None
        The label raster is written to ``output_path``.
    """
    # Read everything needed up front so the source is closed before the
    # (potentially slow) clustering and the output write.
    with rasterio.open(image_path) as src:
        img = src.read()  # (bands, rows, cols)
        transform = src.transform
        crs = src.crs
        height, width = src.height, src.width

    # One row per pixel, one column per band. Use the raster's actual band
    # count: reshaping a 4-band raster (e.g. RGB+NIR) with a hard-coded 3
    # mixes bands across pixels and breaks the reshape back to (height, width).
    # NOTE(review): nodata pixels, if any, are clustered like any other pixel.
    n_bands = img.shape[0]
    pixel_values = img.reshape(n_bands, -1).T.astype(np.float32)

    # n_init is explicit to keep the historical default of 10 and to silence
    # scikit-learn's FutureWarning about that default changing.
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state)
    labels = kmeans.fit_predict(pixel_values)
    segmented_img = labels.reshape((height, width))

    with rasterio.open(
        output_path,
        'w',
        driver='GTiff',
        height=height,
        width=width,
        count=1,
        dtype=segmented_img.dtype,
        crs=crs,
        transform=transform
    ) as dst:
        dst.write(segmented_img, 1)

# Split naip.tif into 7 clusters and write the georeferenced label raster to mask7.tif.
cluster_image(image_path='naip.tif', n_clusters=7, output_path='mask7.tif')


  super()._check_params_vs_input(X, default_n_init=10)


In [3]:
import rasterio