In [1]:
import os

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler

# Read the per-image feature table; one row per image.
complexity_df = pd.read_csv('feature_hm.csv')

# The filename (identifier) and the complexity label are not clustering inputs;
# every remaining column is treated as a numeric feature.
features = complexity_df.drop(columns=['Image Filename', 'Complexity Category'])

# Standardise each feature (zero mean, unit variance) so that K-means' Euclidean
# distances are not dominated by whichever feature has the largest raw range.
scaler = StandardScaler().fit(features)
features_scaled = scaler.transform(features)

# Hyperparameter tuning: choose the number of clusters by silhouette score.
# KMeans.score (what GridSearchCV uses by default) is the negative inertia, which
# keeps improving as the number of clusters grows, so it cannot be used to pick k.
# The silhouette score instead rewards compact, well-separated clusters.
param_grid = {'n_clusters': [2, 3, 4, 5, 6]}  # Candidate numbers of clusters
RANDOM_STATE = 42  # Fixed seed so cluster assignments are reproducible across runs

n_samples = features_scaled.shape[0]
# silhouette_score is only defined for 2 <= number of labels <= n_samples - 1.
candidate_ks = [k for k in param_grid['n_clusters'] if 2 <= k < n_samples]

silhouette_by_k = {}
for k in candidate_ks:
    labels = KMeans(
        n_clusters=k, n_init=10, random_state=RANDOM_STATE
    ).fit_predict(features_scaled)
    # Heavily duplicated data can collapse into a single distinct cluster, for
    # which the silhouette score is undefined; skip that candidate.
    if len(np.unique(labels)) < 2:
        continue
    silhouette_by_k[k] = silhouette_score(features_scaled, labels)

if not silhouette_by_k:
    raise ValueError(
        f"Cannot choose a cluster count: need at least 3 distinct samples, "
        f"got {n_samples} rows."
    )
best_n_clusters = max(silhouette_by_k, key=silhouette_by_k.get)

# Apply K-means clustering with the best number of clusters (same seed as above,
# so this reproduces the clustering that was scored).
kmeans = KMeans(n_clusters=best_n_clusters, n_init=10, random_state=RANDOM_STATE)
kmeans.fit(features_scaled)
cluster_labels = kmeans.labels_

# User input and querying the cluster.
# float() accepts both whole and decimal dimensions.
length = float(input("Enter length: "))
width = float(input("Enter width: "))

# Map user input to an integer label; normalise case/whitespace and re-prompt on
# anything unrecognised instead of failing with a KeyError.
complexity_mapping = {'high': 0, 'medium': 1, 'low': 2}
complexity = input("Enter complexity (high, medium, low): ").strip().lower()
while complexity not in complexity_mapping:
    complexity = input(
        f"Unrecognised complexity {complexity!r}; enter high, medium or low: "
    ).strip().lower()
user_complexity_encoded = complexity_mapping[complexity]
# NOTE(review): user_complexity_encoded is not used below. 'Complexity Category'
# is dropped from the clustering features, so the chosen complexity does not
# influence the predicted cluster. Confirm whether it was meant to filter results.

# Build the query row from the scaler's learned per-feature means, so every
# feature the user did not specify lands at the centre of the data (0 after
# scaling) rather than at a raw value of 0, which may lie far outside the data.
# Keeping the column names also matches how the scaler was fitted.
# Assumes the first two feature columns are length and width — TODO confirm
# against the CSV header.
user_features = pd.DataFrame([scaler.mean_], columns=features.columns)
user_features.iloc[0, :2] = [length, width]
user_input_scaled = scaler.transform(user_features)

# Predict the cluster label for the user input.
predicted_cluster = kmeans.predict(user_input_scaled)[0]

# Display the images that share the predicted cluster in a grid.
IMAGE_DIR = 'images'  # Assuming the images are in a folder named 'images'
N_COLS = 3  # Images per grid row

print("Images in the predicted cluster:")
# cluster_labels is row-aligned with complexity_df (the features were taken from
# it without reordering), so it can be used directly as a boolean row mask.
images_in_cluster = complexity_df.loc[
    cluster_labels == predicted_cluster, 'Image Filename'
].tolist()
num_images = len(images_in_cluster)

if num_images == 0:
    # plt.subplots(0, N_COLS) would raise, so report instead of plotting.
    print("(no images in this cluster)")
else:
    num_rows = (num_images + N_COLS - 1) // N_COLS  # Ceiling division: rows needed

    # squeeze=False keeps `axes` 2-D even for a single row; with the default
    # squeeze=True a 1-row grid returns a 1-D array and 2-D indexing fails.
    fig, axes = plt.subplots(
        num_rows, N_COLS, figsize=(5 * N_COLS, num_rows * 5), squeeze=False
    )
    flat_axes = axes.ravel()

    for ax, image_info in zip(flat_axes, images_in_cluster):
        img_path = os.path.join(IMAGE_DIR, image_info)
        try:
            img = mpimg.imread(img_path)
        except FileNotFoundError:
            # Show a visible placeholder rather than aborting the whole grid.
            ax.text(0.5, 0.5, 'image not found', ha='center', va='center')
        else:
            ax.imshow(img)
        ax.axis('off')
        ax.set_title(image_info, fontsize=10)

    # Remove the unused trailing cells when num_images is not a multiple of N_COLS.
    for ax in flat_axes[num_images:]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.show()

FileNotFoundError: [Errno 2] No such file or directory: 'feature_hm2.csv'