# In [None]:
# Step 1: Import necessary libraries
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np

# Step 2: Download MNIST from OpenML as plain NumPy arrays, then prepare
# the labels and pixel intensities for the steps below.
print("Fetching MNIST dataset...")
X, y = fetch_openml(
    'mnist_784',
    version=1,
    as_frame=False,      # NumPy arrays rather than a pandas DataFrame
    return_X_y=True,     # (features, labels) tuple instead of a Bunch
)
y = y.astype(int)        # labels are not integer-typed as fetched; cast to int class ids
X = X / 255.0            # rescale raw 0-255 pixel values into [0, 1]

# Step 3: Reduce dimensionality to 2D by projecting onto the first two
# principal components (fit on the full dataset, not just the later subsample).
print("Reducing dimensionality of MNIST dataset...")
# random_state pins the result: for input this large, scikit-learn's default
# svd_solver="auto" may select the randomized SVD solver (version-dependent),
# whose output would otherwise vary from run to run.
pca = PCA(n_components=2, random_state=0)
X_reduced = pca.fit_transform(X)

# Step 4: Subsample data for visualization
sample_size = 7000  # Take a smaller subset for clarity
sample_seed = 0     # fixed seed so the same points (and hence trees) are drawn every run

n_points = X_reduced.shape[0]
# choice(..., replace=False) raises if asked for more points than exist.
sample_size = min(sample_size, n_points)

rng = np.random.default_rng(sample_seed)
# Passing the population size as an int avoids materialising range(n) into an array.
indices = rng.choice(n_points, size=sample_size, replace=False)
X_sample = X_reduced[indices]
y_sample = y[indices]

# Step 5: Build and visualize the k-d tree and the ball tree.
# NOTE: `build_kd_tree`, `plot_kd_tree`, `build_ball_tree`, and `plot_ball_tree`
# are not defined here; they must be defined by earlier code (e.g. a previous
# notebook cell) before this section runs, otherwise it raises NameError.

# Bounding box of the sample in PCA space as (x_min, x_max, y_min, y_max),
# computed once and shared by both figures.
sample_bounds = (
    X_sample[:, 0].min(), X_sample[:, 0].max(),
    X_sample[:, 1].min(), X_sample[:, 1].max(),
)


def _new_tree_figure(title, bounds, pad=1.0):
    """Create an 8x8-inch figure with labelled, gridded axes for a 2-D tree plot.

    Args:
        title: Axes title.
        bounds: ``(x_min, x_max, y_min, y_max)`` of the data being plotted.
        pad: Margin, in data units, added on every side of ``bounds`` when
            setting the view limits.

    Returns:
        The ``(fig, ax)`` pair created by ``plt.subplots``.
    """
    x_lo, x_hi, y_lo, y_hi = bounds
    new_fig, new_ax = plt.subplots(figsize=(8, 8))
    new_ax.set_title(title)
    new_ax.set_xlabel("PCA Dimension 1")
    new_ax.set_ylabel("PCA Dimension 2")
    new_ax.set_xlim(x_lo - pad, x_hi + pad)
    new_ax.set_ylim(y_lo - pad, y_hi + pad)
    new_ax.grid(True)
    return new_fig, new_ax


# Build and visualize the k-d tree. `bounds` is passed unpadded — presumably the
# region covered by the root cell; confirm against plot_kd_tree's definition.
kd_tree_sample_root = build_kd_tree(X_sample)
fig, ax = _new_tree_figure(
    "k-d Tree Visualization (MNIST Reduced to 2D)", sample_bounds
)
plot_kd_tree(kd_tree_sample_root, ax, bounds=sample_bounds)
plt.show()

# Build and visualize the ball tree
ball_tree_sample_root = build_ball_tree(X_sample)
fig, ax = _new_tree_figure(
    "Ball Tree Visualization (MNIST Reduced to 2D)", sample_bounds
)
plot_ball_tree(ball_tree_sample_root, ax=ax)
plt.show()


# Output: Fetching MNIST dataset...
