In [None]:
import numpy as np
import pandas as pd  # اضافه کردن pandas
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# 1. Load the MNIST dataset ('mnist_784', version 1) from OpenML.
mnist = fetch_openml(name='mnist_784', version=1)

# Force every pixel column to a numeric dtype: entries pd.to_numeric cannot
# parse become NaN (errors='coerce') and are then replaced by 0, after which
# the frame is cast to float32 and turned into a plain NumPy array.
X = (
    mnist.data
    .apply(pd.to_numeric, errors='coerce')
    .fillna(0)
    .astype('float32')
    .to_numpy()
)
# Class labels as integers.
y = mnist.target.astype(int)

# 2. Hold out 20% of the samples for testing; random_state fixes the shuffle.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42, test_size=0.2
)

# 3. Define the image-shifting function
def shift_image(image, direction, shift=1, shape=(28, 28)):
    """Return a flattened copy of ``image`` translated by ``shift`` pixels.

    Pixels pushed past the border are dropped and the vacated rows/columns are
    filled with zeros, so (unlike a plain ``np.roll``) nothing wraps around.
    The input array is never modified.

    Parameters
    ----------
    image : array-like
        Image with ``shape[0] * shape[1]`` elements, e.g. a flattened 784-pixel
        MNIST digit.
    direction : {'left', 'right', 'up', 'down'}
        Direction in which the image content moves.
    shift : int, default 1
        Number of pixels to move. ``0`` returns an unshifted copy; a shift at
        least as large as the image extent along that axis yields all zeros.
    shape : tuple of int, default (28, 28)
        ``(rows, cols)`` used to interpret ``image``.

    Returns
    -------
    numpy.ndarray
        1-D array with the same dtype and element count as ``image``.

    Raises
    ------
    ValueError
        If ``direction`` is not one of the four supported values, if ``shift``
        is negative, or if ``image`` cannot be reshaped to ``shape``.
    """
    grid = np.asarray(image).reshape(shape)
    if shift < 0:
        raise ValueError("shift must be a non-negative number of pixels")
    rows, cols = grid.shape
    shifted = np.zeros_like(grid)
    # Clamp the shift to the axis length so that oversized shifts still give
    # equally sized (empty) slices on both sides of the assignment.
    if direction == 'left':
        s = min(shift, cols)
        shifted[:, :cols - s] = grid[:, s:]
    elif direction == 'right':
        s = min(shift, cols)
        shifted[:, s:] = grid[:, :cols - s]
    elif direction == 'up':
        s = min(shift, rows)
        shifted[:rows - s, :] = grid[s:, :]
    elif direction == 'down':
        s = min(shift, rows)
        shifted[s:, :] = grid[:rows - s, :]
    else:
        raise ValueError("Direction must be one of ['left', 'right', 'up', 'down']")
    return shifted.flatten()

# 4. Build the augmented training set: each training image is followed by its
#    one-pixel shifts to the left, right, up and down (5 rows per original).
directions = ('left', 'right', 'up', 'down')
n_copies = 1 + len(directions)

# Pre-allocate the result instead of appending every row to a Python list and
# copying the whole list into a new array afterwards, which roughly doubles
# peak memory for a (5 * len(X_train), 784) float32 array.
X_train_augmented = np.empty(
    (len(X_train) * n_copies, X_train.shape[1]), dtype=X_train.dtype
)
for i, image in enumerate(X_train):
    row = i * n_copies
    X_train_augmented[row] = image
    for offset, direction in enumerate(directions, start=1):
        X_train_augmented[row + offset] = shift_image(image, direction)

# Shifting does not change the digit, so each label is repeated n_copies times,
# matching the row order produced above.
y_train_augmented = np.repeat(np.asarray(y_train), n_copies)

# 5. Distance-weighted k-nearest-neighbours classifier (k = 5) trained on the
#    augmented set; fit() returns the estimator, so construction and training
#    are chained.
knn = KNeighborsClassifier(weights='distance', n_neighbors=5).fit(
    X_train_augmented, y_train_augmented
)

# 6. Evaluate on the held-out test split (which is not augmented).
y_pred = knn.predict(X_test)
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)

print("Accuracy after data augmentation:", accuracy)
