# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.manifold import TSNE
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

# Read the stored image array, then keep only its leading 600 entries.
print("Loading data...")
digits_full = np.load('unlabelled_train_data_images.npy')
digits = digits_full[0:600]

# One flat row per image, with pixel values divided by 255.0.
X = digits.reshape(len(digits), -1) / 255.0

print(f"Working with {len(digits)} digit images")

def create_enhanced_cnn_embeddings(X: np.ndarray, encoding_dim: int = 64):
    """
    Train a convolutional autoencoder on X and return the encoder's embeddings.

    An encoder (three convolutional blocks followed by dense layers) and a
    decoder (dense layers followed by transposed convolutions) are built on
    the same input tensor. The autoencoder is trained to reconstruct its
    input with an MSE loss; the encoder half is then used on its own to
    embed every sample of X.

    Args:
        X: Image data that can be reshaped to (-1, 28, 28, 1). The decoder
            ends in a sigmoid and is trained with MSE against the input, so
            pixel values are presumably expected in [0, 1] -- confirm
            against the caller's normalisation.
        encoding_dim: Number of units in the final 'embedding' Dense layer,
            i.e. the dimensionality of each returned embedding.

    Returns:
        A tuple (embeddings, feature_extractor, autoencoder):
        embeddings is the output of feature_extractor.predict on all of X
        (including the samples held out for validation during training);
        feature_extractor is the input -> 'embedding' model; autoencoder is
        the input -> reconstruction model. Both models are built from the
        same layer graph, so fitting the autoencoder also trains the
        feature extractor.

    Side effects:
        Prints progress messages, runs 30 training epochs, and shows a
        matplotlib figure of the training/validation loss.
    """
    # Reshape for CNN
    input_shape = (28, 28, 1)
    X_reshaped = X.reshape(-1, 28, 28, 1)
    
    # Define a more powerful CNN model specifically for MNIST
    inputs = tf.keras.Input(shape=input_shape)
    
    # First block - capture basic features
    # NOTE(review): newer Keras releases name this LeakyReLU argument
    # `negative_slope` instead of `alpha` -- confirm against the installed version.
    x = layers.Conv2D(32, (5, 5), padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.MaxPooling2D((2, 2))(x)  # 28x28 -> 14x14
    
    # Second block - more complex features
    x = layers.Conv2D(64, (3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.Conv2D(64, (3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.MaxPooling2D((2, 2))(x)  # 14x14 -> 7x7
    
    # Third block - high-level features (no pooling, spatial size stays 7x7)
    x = layers.Conv2D(128, (3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    
    # Flatten and create dense representations
    x = layers.Flatten()(x)
    x = layers.Dense(256)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.Dropout(0.3)(x)
    
    # Final embedding layer (linear: no activation is applied)
    encoded = layers.Dense(encoding_dim, name='embedding')(x)
    
    # Model for feature extraction; shares every layer above with the autoencoder
    feature_extractor = models.Model(inputs=inputs, outputs=encoded)
    
    # Create powerful decoder for pretraining
    decoded = layers.Dense(256, activation='relu')(encoded)
    decoded = layers.Dense(7*7*64, activation='relu')(decoded)
    decoded = layers.Reshape((7, 7, 64))(decoded)
    # Two stride-2 transposed convolutions upsample 7x7 -> 14x14 -> 28x28
    decoded = layers.Conv2DTranspose(64, (3, 3), strides=2, padding='same', activation='relu')(decoded)
    decoded = layers.Conv2DTranspose(32, (3, 3), strides=2, padding='same', activation='relu')(decoded)
    # Single-channel sigmoid output so the reconstruction has the input's shape
    decoded = layers.Conv2D(1, (5, 5), padding='same', activation='sigmoid')(decoded)
    
    # Build autoencoder
    autoencoder = models.Model(inputs=inputs, outputs=decoded)
    autoencoder.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse')
    
    # Train with more epochs (smaller dataset allows this)
    print("Pretraining feature extractor (this is critical for good clustering)...")
    history = autoencoder.fit(
        X_reshaped, X_reshaped,  # Input and target are the same for autoencoder
        epochs=30,  # More epochs for better feature learning
        batch_size=32,  # Smaller batch size for better learning
        shuffle=True,
        validation_split=0.2,  # Part of the data is used only for val_loss
        verbose=1
    )
    
    # Plot the training history
    plt.figure(figsize=(10, 4))
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Autoencoder Training')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(['Training', 'Validation'])
    plt.show()
    
    # Generate embeddings for every sample, using the now-trained shared encoder
    embeddings = feature_extractor.predict(X_reshaped)
    print(f"Generated embeddings with shape: {embeddings.shape}")
    
    return embeddings, feature_extractor, autoencoder

def optimal_clustering(embeddings, n_clusters=10, random_state=42):
    """
    Cluster the embeddings with several methods and keep the best one.

    Three candidates are evaluated in order: KMeans on the raw embeddings,
    a full-covariance Gaussian mixture on the raw embeddings, and KMeans on
    standardized embeddings. Each is scored with the silhouette coefficient
    and the highest-scoring labelling wins (ties keep the earlier method).

    Note that the "KMeans+Scaling" score is measured in the standardized
    space while the other two are measured in the raw space, so the
    comparison between them is only approximate.

    Args:
        embeddings: 2-D array-like of shape (n_samples, n_features).
        n_clusters: Number of clusters/components requested from every
            method. Defaults to 10 (one per digit).
        random_state: Seed passed to KMeans and GaussianMixture.

    Returns:
        A tuple (best_labels, best_score, best_method). If no method yields
        a labelling for which the silhouette score is defined, the tuple is
        (None, -1, None).
    """
    print("\nTrying different clustering approaches...")

    def fit_kmeans(data):
        model = KMeans(n_clusters=n_clusters, n_init=20, random_state=random_state)
        return model.fit_predict(data)

    def fit_gmm(data):
        model = GaussianMixture(n_components=n_clusters, covariance_type='full',
                                n_init=10, random_state=random_state)
        model.fit(data)
        return model.predict(data)

    def raw():
        return embeddings

    def standardized():
        return StandardScaler().fit_transform(embeddings)

    # (description used in the progress line, method name, data factory, fitter)
    candidates = [
        ("KMeans", "KMeans", raw, fit_kmeans),
        ("Gaussian Mixture Model", "GMM", raw, fit_gmm),
        ("KMeans with standardization", "KMeans+Scaling", standardized, fit_kmeans),
    ]

    best_score = -1
    best_labels = None
    best_method = None

    for description, method, get_data, fit in candidates:
        print(f"Testing {description}...")
        data = get_data()
        labels = fit(data)

        # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1
        # and raises ValueError otherwise; a method that collapses to a single
        # cluster should be skipped rather than abort the whole comparison.
        n_found = len(np.unique(labels))
        if not 2 <= n_found <= len(labels) - 1:
            print(f"  {method} produced {n_found} distinct cluster(s); "
                  "silhouette score is undefined, skipping")
            continue

        score = silhouette_score(data, labels)
        print(f"  {method} silhouette score: {score:.4f}")

        if best_labels is None or score > best_score:
            best_score = score
            best_labels = labels
            best_method = method

    print(f"\nBest clustering method: {best_method} (score: {best_score:.4f})")
    return best_labels, best_score, best_method

def visualize_digit_clusters(embeddings, labels, digit_images):
    """
    Plot a 2-D t-SNE projection of the embeddings coloured by cluster.

    After the scatter plot is shown, a grid of example images per cluster is
    drawn via visualize_cluster_examples.

    Args:
        embeddings: 2-D array of shape (n_samples, n_features) to project.
        labels: Cluster id per sample, used to colour the scatter points.
        digit_images: Images aligned with `embeddings`, forwarded to
            visualize_cluster_examples.

    Returns:
        The t-SNE coordinates, an array of shape (n_samples, 2).
    """
    import inspect

    # Use t-SNE for dimensionality reduction
    print("\nCreating t-SNE visualization...")

    # scikit-learn renamed TSNE's `n_iter` to `max_iter` (deprecated in 1.5,
    # removed in 1.7). Pick whichever keyword the installed version accepts so
    # the call works on both old and new releases.
    tsne_params = inspect.signature(TSNE.__init__).parameters
    iter_kwarg = 'max_iter' if 'max_iter' in tsne_params else 'n_iter'

    # t-SNE requires perplexity < n_samples; keep 30 whenever the data allows it.
    perplexity = min(30, max(len(embeddings) - 1, 1))

    tsne = TSNE(
        n_components=2,
        perplexity=perplexity,
        learning_rate='auto',
        random_state=42,
        **{iter_kwarg: 2000}
    )
    embeddings_2d = tsne.fit_transform(embeddings)

    # Create scatter plot
    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        c=labels,
        cmap='tab10',
        alpha=0.7,
        s=50
    )
    plt.colorbar(scatter, label='Cluster')
    plt.title('t-SNE Visualization of Digit Clusters', fontsize=16)
    plt.tight_layout()
    plt.show()

    # Show examples from each cluster
    visualize_cluster_examples(digit_images, labels)

    return embeddings_2d

def visualize_cluster_examples(images, labels, samples_per_row=10):
    """
    Visualize examples from each cluster in a grid.

    One row is drawn per distinct value in `labels` (in sorted order), so
    cluster ids do not have to be the contiguous range 0..k-1.

    Args:
        images: Indexable collection of images; each selected image is
            reshaped to 28x28 for display.
        labels: Cluster id per image, aligned with `images`.
        samples_per_row: Maximum number of examples shown per cluster.
    """
    labels = np.asarray(labels)
    cluster_ids = np.unique(labels)
    n_clusters = len(cluster_ids)

    if n_clusters == 0:
        # Nothing to draw; a zero-height figure would not be usable.
        return

    plt.figure(figsize=(15, 2*n_clusters))

    # Rows are indexed by position in cluster_ids, not by the id itself, so a
    # missing id (e.g. an empty mixture component) cannot hide another cluster.
    for row, cluster_id in enumerate(cluster_ids):
        cluster_samples = np.where(labels == cluster_id)[0]

        for col, sample_idx in enumerate(cluster_samples[:samples_per_row]):
            plt.subplot(n_clusters, samples_per_row, row*samples_per_row + col + 1)
            plt.imshow(images[sample_idx].reshape(28, 28), cmap='gray')
            plt.axis('off')

            # Add cluster label to first image in row
            if col == 0:
                plt.title(f'Cluster {cluster_id}')

    plt.tight_layout()
    plt.subplots_adjust(hspace=0.3)
    plt.suptitle("Samples from Each Cluster", fontsize=16, y=0.92)
    plt.show()

def find_cluster_representatives(images, labels):
    """
    Find and display one representative image for each cluster.

    For every distinct value in `labels`, the representative is the member
    image with the smallest Euclidean distance (over flattened pixels) to
    the mean image of that cluster.

    Args:
        images: Array of images indexable by an integer index array; each
            image is reshaped to 28x28 for display.
        labels: Cluster id per image, aligned with `images`.

    Returns:
        A list of (cluster_id, image_index) tuples, one per distinct cluster
        id in sorted order.
    """
    labels = np.asarray(labels)
    cluster_ids = np.unique(labels)
    n_clusters = len(cluster_ids)
    representatives = []

    # Five representatives per row; at least two rows so the usual 10-cluster
    # case keeps its 2x5 layout, with more rows added for larger cluster counts.
    n_cols = 5
    n_rows = max(2, -(-n_clusters // n_cols))

    plt.figure(figsize=(15, 3 * n_rows))

    # Subplot slots follow the position in cluster_ids, so non-contiguous ids
    # are still all processed and always land inside the grid.
    for position, cluster_id in enumerate(cluster_ids, start=1):
        cluster_indices = np.where(labels == cluster_id)[0]

        cluster_images = images[cluster_indices].reshape(len(cluster_indices), -1)
        cluster_center = np.mean(cluster_images, axis=0)

        # Find image closest to center
        distances = np.linalg.norm(cluster_images - cluster_center, axis=1)
        closest_idx = cluster_indices[np.argmin(distances)]
        # .item() yields a plain Python scalar for the cluster id
        representatives.append((cluster_id.item(), closest_idx))

        # Display the representative
        plt.subplot(n_rows, n_cols, position)
        plt.imshow(images[closest_idx].reshape(28, 28), cmap='gray')
        plt.title(f"Cluster {cluster_id}")
        plt.axis('off')

    plt.tight_layout()
    plt.suptitle("Representative Examples from Each Cluster", y=0.98, fontsize=16)
    plt.show()

    return representatives

def assign_digit_labels(representatives, labels, cluster_to_digit=None):
    """
    Translate cluster ids into digit labels using a cluster -> digit mapping.

    The mapping normally comes from a human inspecting the representative
    image of each cluster. Pass it as `cluster_to_digit`; when it is omitted
    a placeholder identity mapping for clusters 0-9 is used and a reminder
    to supply a real mapping is printed.

    Args:
        representatives: Cluster representatives as returned by
            find_cluster_representatives. Currently unused; kept so existing
            callers keep working.
        labels: Cluster id per sample (array-like).
        cluster_to_digit: Optional dict mapping cluster id -> digit.

    Returns:
        A tuple (digit_labels, cluster_to_digit). digit_labels has the same
        shape and dtype as `labels`; samples whose cluster id is not a key of
        the mapping are left at 0.
    """
    # Coerce to an array: with a plain list, `labels == cluster_id` is a
    # single False and no label would ever be assigned.
    labels = np.asarray(labels)

    if cluster_to_digit is None:
        print("\nAfter examining the representative examples, please create a mapping from cluster IDs to digits.")
        print("For example: cluster_to_digit = {0: 7, 1: 2, 2: 1, 3: 0, 4: 4, ...}")
        # Placeholder 1:1 mapping, to be replaced after visual inspection
        cluster_to_digit = {i: i for i in range(10)}
    else:
        # Copy so the caller's dict is never shared with the returned one
        cluster_to_digit = dict(cluster_to_digit)

    # Masks are computed on the original labels, so remapping one cluster can
    # never be re-mapped again by a later entry.
    digit_labels = np.zeros_like(labels)
    for cluster_id, digit in cluster_to_digit.items():
        digit_labels[labels == cluster_id] = digit

    return digit_labels, cluster_to_digit

# ====== Main Execution ======
# Runs the whole pipeline on the loaded subset: embed -> cluster -> visualize
# -> pick representatives -> map clusters to digits -> save.
print("Starting enhanced MNIST digit clustering pipeline...")

# Step 1: Generate high-quality embeddings with deep CNN
embeddings, extractor, autoencoder = create_enhanced_cnn_embeddings(X)

# Step 2: Perform optimal clustering for exactly 10 clusters
cluster_labels, score, method = optimal_clustering(embeddings)

# Step 3: Visualize the clusters
tsne_coords = visualize_digit_clusters(embeddings, cluster_labels, digits)

# Step 4: Find representative examples from each cluster
representatives = find_cluster_representatives(digits, cluster_labels)

# Step 5: Create mapping (this would typically require manual inspection)
digit_labels, mapping = assign_digit_labels(representatives, cluster_labels)

# Step 6: Save results
# The .npy file holds the mapped digit labels (not the raw cluster ids), so a
# corrected cluster-to-digit mapping is actually reflected in the saved output.
np.save('digit_labels_small_dataset.npy', digit_labels)
# The raw cluster ids stay available under 'original_labels'; the mapping is
# stored as two parallel arrays so the archive needs no pickled objects.
np.savez('clustering_results_small_dataset.npz',
         original_labels=cluster_labels,
         digit_labels=digit_labels,
         mapping_cluster_ids=np.array(list(mapping.keys())),
         mapping_digits=np.array(list(mapping.values())),
         embeddings=embeddings,
         tsne=tsne_coords,
         representatives=[r[1] for r in representatives])

print("\nClustering complete!")
print("Review the representative examples and manually update the cluster-to-digit mapping.")
print("You can modify the 'assign_digit_labels' function with the correct mapping.")