In [1]:
import numpy as np
from collections import Counter

class KNNClassifier:
    """k-nearest-neighbours classifier using Euclidean distance and majority vote.

    Training data is converted to NumPy arrays in ``fit``, so plain lists,
    NumPy arrays and pandas objects are all accepted and labels are always
    looked up by position, never by index label.
    """

    def __init__(self, k=3):
        """Create a classifier that votes among the ``k`` closest training samples.

        Args:
            k: Number of neighbours consulted per prediction; must be >= 1.

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError(f"k must be at least 1, got {k!r}")
        self.k = k

    def fit(self, X, y):
        """Memorise the training samples and their labels.

        Args:
            X: Array-like of training samples, one sample per entry.
            y: Array-like of labels, one per sample in ``X``.

        Raises:
            ValueError: If the training set is empty or ``X`` and ``y`` differ
                in length.
        """
        # Float conversion keeps unsigned-integer features from wrapping
        # around when two samples are subtracted.
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)
        if len(X) == 0:
            raise ValueError("cannot fit on an empty training set")
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.X_train = X
        self.y_train = y

    def euclidean_distance(self, x1, x2):
        """Return the Euclidean distance between two samples.

        Args:
            x1: First sample (array-like of numbers).
            x2: Second sample, same shape as ``x1``.

        Returns:
            The square root of the summed squared differences.
        """
        difference = np.asarray(x1, dtype=float) - np.asarray(x2, dtype=float)
        return np.sqrt(np.sum(difference ** 2))

    def predict(self, X):
        """Predict a label for every sample in ``X``.

        Each query is compared with all training samples at once. The ``k``
        closest samples vote; when several labels share the highest count,
        the label of the nearest of them wins, and equally distant training
        samples are ranked in training order.

        Args:
            X: Array-like of query samples, one sample per entry.

        Returns:
            A list with one predicted label per query sample.
        """
        # Convert again here so attributes assigned directly (bypassing fit)
        # still work.
        train = np.asarray(self.X_train, dtype=float)
        labels = np.asarray(self.y_train)
        # One flat feature vector per training sample lets a single
        # broadcasted subtraction produce all distances for a query.
        train = train.reshape(len(train), -1)

        y_pred = []
        for query in np.asarray(X, dtype=float):
            offsets = train - np.ravel(query)
            distances = np.sqrt(np.sum(offsets ** 2, axis=1))
            # A stable sort keeps equally distant samples in training order.
            nearest = np.argsort(distances, kind="stable")[:self.k]
            votes = Counter(labels[nearest])
            y_pred.append(votes.most_common(1)[0][0])
        return y_pred


In [3]:
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# Generate a reproducible synthetic dataset with two informative features
X, y = make_classification(n_samples=1000, n_features=2, n_informative=2,
                           n_redundant=0, n_clusters_per_class=1, random_state=0)

# Split the dataset into training and testing sets (20% held out for testing)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0)

# Create an instance of the KNN classifier from scikit-learn
knn = KNeighborsClassifier(n_neighbors=3)

# Fit the KNN classifier on the synthetic training data
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)

# Calculate and print the accuracy of the scikit-learn KNN classifier
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy on test data:", accuracy)

# Repeat the experiment with the hand-written KNNClassifier
knn_by_my_function = KNNClassifier(k=3)
knn_by_my_function.fit(X_train, y_train)
# Predict with the custom model itself; calling knn.predict here would only
# re-score the scikit-learn classifier and make the comparison meaningless.
y_pred_by_my_function = knn_by_my_function.predict(X_test)

# Calculate and print the accuracy of the hand-written KNN classifier
accuracy2 = accuracy_score(y_test, y_pred_by_my_function)
print("Accuracy on test data:", accuracy2)


Accuracy on synthetic test data: 0.915
Accuracy on synthetic test data: 0.915
