-
Notifications
You must be signed in to change notification settings - Fork 0
Description
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.datasets import load_iris, load_digits
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd
from IPython.display import display
class KNNClassifier:
    """k-nearest-neighbours classifier implemented from scratch.

    Supports Euclidean and Manhattan distance metrics and two voting
    schemes: uniform (plain majority vote) and inverse-distance weighting.

    Parameters
    ----------
    k : int
        Number of neighbours that vote for the predicted label.
    metric : str
        'euclidean' or 'manhattan'; anything else raises ValueError at
        prediction time.
    weighting : str
        'uniform' for majority voting; any other value (conventionally
        'distance') weights each neighbour by ``1 / (distance + eps)``.
    eps : float
        Small constant preventing division by zero when a query point
        coincides exactly with a training point.
    """

    def __init__(self, k=5, metric='euclidean', weighting='uniform', eps=1e-8):
        self.k = k
        self.metric = metric
        self.weighting = weighting
        self.eps = eps
        self.X_train = None  # set by fit()
        self.y_train = None  # set by fit()

    def _distance(self, x1, x2):
        """Distance between two single points (kept for backward compatibility)."""
        if self.metric == 'euclidean':
            return np.sqrt(np.sum((x1 - x2) ** 2))
        elif self.metric == 'manhattan':
            return np.sum(np.abs(x1 - x2))
        else:
            raise ValueError(f"Метрика {self.metric} не поддерживается")

    def _all_distances(self, x):
        """Vector of distances from query *x* to every training point.

        Vectorized with numpy instead of a Python-level loop over rows —
        same values as calling ``_distance`` per pair, but far faster.
        """
        diff = self.X_train - x
        if self.metric == 'euclidean':
            return np.sqrt(np.sum(diff ** 2, axis=1))
        elif self.metric == 'manhattan':
            return np.sum(np.abs(diff), axis=1)
        else:
            raise ValueError(f"Метрика {self.metric} не поддерживается")

    def _predict_one(self, x):
        """Predict the label of a single query point."""
        dists = self._all_distances(x)
        # Stable sort preserves the original tie-breaking: equal distances
        # keep training-set order, exactly like a stable Python sort over
        # (distance, index) pairs.
        nearest = np.argsort(dists, kind='stable')[:self.k]
        if self.weighting == 'uniform':
            labels = [self.y_train[i] for i in nearest]
            # most_common breaks ties by first occurrence, i.e. by the
            # nearest neighbour carrying that label.
            return Counter(labels).most_common(1)[0][0]
        else:
            # Inverse-distance weighting: closer neighbours count more.
            weight_sum = {}
            for i in nearest:
                w = 1.0 / (dists[i] + self.eps)
                label = self.y_train[i]
                weight_sum[label] = weight_sum.get(label, 0) + w
            return max(weight_sum.items(), key=lambda kv: kv[1])[0]

    def fit(self, X, y):
        """Memorize the training set; returns self for call chaining."""
        self.X_train = np.array(X)
        self.y_train = np.array(y)
        return self

    def predict(self, X):
        """Return an array with one predicted label per row of X."""
        X = np.array(X)
        return np.array([self._predict_one(x) for x in X])
# ====================== Function for full dataset analysis ======================
def analyze_dataset(dataset_name, X, y, k_max=15, cv_folds=5):
    """Run a full KNN hyperparameter study on one dataset.

    Cross-validates KNNClassifier over k = 1..k_max for every combination
    of distance metric (euclidean / manhattan) and vote weighting
    (uniform / distance), plots mean CV accuracy as a function of k,
    and prints the best configuration found.

    Parameters
    ----------
    dataset_name : str
        Human-readable name used in printed output and the plot title.
    X, y : array-like
        Feature matrix and label vector.
    k_max : int
        Largest number of neighbours to try (inclusive).
    cv_folds : int
        Number of KFold cross-validation folds.

    Returns
    -------
    (dict, pandas.DataFrame, matplotlib.figure.Figure)
        Best parameter row as a dict, the full results table, and the figure.
    """
    print(f"\n========== Анализ датасета: {dataset_name} ==========")

    # Standardize features: KNN distances are scale-sensitive.
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    k_range = range(1, k_max + 1)
    metrics = ['euclidean', 'manhattan']
    weightings = ['uniform', 'distance']

    # With a fixed random_state the folds are identical for every
    # configuration, so compute the splits once instead of rebuilding
    # KFold inside the innermost loop (loop-invariant hoist).
    kf = KFold(n_splits=cv_folds, shuffle=True, random_state=42)
    splits = list(kf.split(X_scaled))

    results = []
    styles = {
        ('euclidean', 'uniform'): ('b-', 'Евклидова, равномерное'),
        ('euclidean', 'distance'): ('g-', 'Евклидова, взвешенное'),
        ('manhattan', 'uniform'): ('r-', 'Манхэттенская, равномерное'),
        ('manhattan', 'distance'): ('m-', 'Манхэттенская, взвешенное')
    }
    fig, ax = plt.subplots(figsize=(10, 6))

    for metric in metrics:
        for weighting in weightings:
            cv_scores_by_k = []
            for k in k_range:
                scores = []
                for train_idx, val_idx in splits:
                    X_train, X_val = X_scaled[train_idx], X_scaled[val_idx]
                    y_train, y_val = y[train_idx], y[val_idx]
                    model = KNNClassifier(k=k, metric=metric, weighting=weighting)
                    model.fit(X_train, y_train)
                    y_pred = model.predict(X_val)
                    scores.append(accuracy_score(y_val, y_pred))
                mean_score = np.mean(scores)
                cv_scores_by_k.append(mean_score)
                results.append({
                    'k': k,
                    'metric': metric,
                    'weighting': weighting,
                    'accuracy': mean_score
                })
            color_style, label = styles[(metric, weighting)]
            ax.plot(k_range, cv_scores_by_k, color_style, linewidth=2, label=label)

    # Best configuration = row with the highest mean CV accuracy.
    df_results = pd.DataFrame(results)
    best_idx = df_results['accuracy'].idxmax()
    best_row = df_results.loc[best_idx]
    print(f"\nЛучшие параметры (по кросс-валидации):")
    print(f"  k = {best_row['k']:.0f}")
    print(f"  метрика = {best_row['metric']}")
    print(f"  взвешивание = {best_row['weighting']}")
    print(f"  средняя точность = {best_row['accuracy']:.4f}")

    ax.set_xlabel('k (число соседей)', fontsize=12)
    ax.set_ylabel('Средняя точность на кросс-валидации', fontsize=12)
    ax.set_title(f'{dataset_name}: зависимость точности от k', fontsize=14)
    ax.legend(loc='best')
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    return best_row.to_dict(), df_results, fig
def _show(fig):
    """Render a finished figure in the notebook, then release its memory."""
    display(fig)
    plt.close(fig)

# Run the full study on both benchmark datasets.
iris = load_iris()
digits = load_digits()

best_iris, df_iris, fig_iris = analyze_dataset("Iris", iris.data, iris.target, k_max=15)
_show(fig_iris)

best_digits, df_digits, fig_digits = analyze_dataset("Digits", digits.data, digits.target, k_max=15)
_show(fig_digits)

print("\n========== Исследование завершено ==========")