In [1]:
import random
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split

In [2]:
# Load the Iris data as a (features, labels) pair rather than a Bunch object.
X, y = load_iris(return_X_y=True)

In [3]:
# Hold out 15% of the samples for testing; the fixed seed keeps the split
# identical across kernel restarts.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.15,
    random_state=42,
)

In [4]:
class CondensedNearestNeighbors:
    """Under-sample a labelled dataset with Hart's condensed nearest-neighbour rule.

    A "store" is seeded with the first sample of every class. The data is then
    scanned repeatedly: every sample that the k-NN rule built on the current
    store misclassifies is added to the store. Scanning stops when a full pass
    adds nothing. No class ever keeps more samples than the smallest class
    contains, so the result is never more imbalanced than a plain
    down-sampling to the minority size would be.

    The procedure is fully deterministic (samples are visited in input order).

    Parameters
    ----------
    k_neighbors : int, default=1
        Number of neighbours that vote on a sample's label. It is clamped to
        the current store size, so values larger than the store are safe.
    """

    def __init__(self, k_neighbors=1):
        self.k_neighbors = k_neighbors

    def _predict_label(self, X_store, y_store, x):
        """Return the majority label of the k stored samples closest to ``x``."""
        k = min(self.k_neighbors, y_store.shape[0])
        # Squared Euclidean distance gives the same ranking as the true distance.
        sq_dist = np.sum((X_store - x) ** 2, axis=1)
        nearest = np.argsort(sq_dist, kind="stable")[:k]
        labels, votes = np.unique(y_store[nearest], return_counts=True)
        # np.unique sorts the labels, so ties resolve to the smallest label.
        return labels[np.argmax(votes)]

    def fit_resample(self, X, y):
        """Select a condensed subset of ``(X, y)``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Numeric feature matrix.
        y : array-like of shape (n_samples,)
            Class label of every row of ``X``.

        Returns
        -------
        X_resampled : ndarray of shape (n_kept, n_features)
        y_resampled : ndarray of shape (n_kept,)
            The kept rows, grouped by class in ascending label order and in
            input order within each class.

        Raises
        ------
        ValueError
            If ``k_neighbors < 1``, ``X`` is not 2-D, or ``y`` is not 1-D with
            one entry per row of ``X``.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if self.k_neighbors < 1:
            raise ValueError("k_neighbors must be at least 1")
        if X.ndim != 2:
            raise ValueError("X must be a 2-D array of shape (n_samples, n_features)")
        if y.ndim != 1 or y.shape[0] != X.shape[0]:
            raise ValueError("y must be a 1-D array with one label per row of X")
        if y.shape[0] == 0:
            return X[:0], y[:0]

        classes, first_idx, class_idx, class_sizes = np.unique(
            y, return_index=True, return_inverse=True, return_counts=True
        )
        class_idx = np.ravel(class_idx)
        min_class_size = class_sizes.min()
        features = X.astype(float)

        # The store must contain every class from the start: a k-NN rule that
        # only knows one label can never misclassify a sample of that label,
        # so a per-class condensation would never add anything.
        selected = np.zeros(y.shape[0], dtype=bool)
        selected[first_idx] = True
        kept_per_class = np.ones(classes.shape[0], dtype=int)

        changed = True
        while changed:  # terminates: the store only grows and is bounded
            changed = False
            for i in range(y.shape[0]):
                if selected[i] or kept_per_class[class_idx[i]] >= min_class_size:
                    continue
                store = np.flatnonzero(selected)
                predicted = self._predict_label(features[store], y[store], features[i])
                if predicted != y[i]:
                    selected[i] = True
                    kept_per_class[class_idx[i]] += 1
                    changed = True

        order = np.flatnonzero(selected)
        # Stable sort keeps input order inside each class block.
        order = order[np.argsort(class_idx[order], kind="stable")]
        return X[order], y[order]

In [None]:
# Condense the training split with a single-neighbour rule.
cnn = CondensedNearestNeighbors(k_neighbors=1)
X_train_resampled, y_train_resampled = cnn.fit_resample(X=X_train, y=y_train)

In [None]:
# Report the per-class sample counts before and after condensing.
for _caption, _labels in (
    ("y_train:", y_train),
    ("y_train_resampled:", y_train_resampled),
):
    print(_caption, Counter(_labels))

In [None]:
# Compare the full training split with the subset kept by the condensed
# nearest-neighbour step: first feature on the x-axis, class label on the y-axis.
fig, ax = plt.subplots(figsize=(12, 8))
# Label each series directly so the legend cannot drift out of sync with the
# order of the scatter calls.
ax.scatter(X_train[:, 0], y_train, color="r", marker="o", label="Raw Data")
# CNN keeps a subset of existing samples; it does not generate new ones.
ax.scatter(
    X_train_resampled[:, 0],
    y_train_resampled,
    color="b",
    marker="*",
    label="CNN Selected",
)
ax.set_xlabel("Feature 0")
ax.set_ylabel("Class label")
ax.set_title("Training data before and after condensed nearest-neighbour resampling")
ax.legend()
plt.show()