# k近邻算法---利用k近邻算法对手写字体进行识别

In [1]:
import torch
import numpy as np
import operator
from os import listdir


In [38]:
def KNN(inx, dataset, labels, k, distances_way, p=None):
    """Classify one sample by majority vote among its k nearest neighbours.

    Args:
        inx: Feature vector to classify, shape (d,) or (1, d).
        dataset: Training matrix with one sample per row, shape (n, d).
        labels: Sequence of n labels; labels[i] is the class of dataset[i].
        k: Number of nearest neighbours that vote. Values larger than n are
            clamped to n.
        distances_way: Distance metric selector: 'o' for Euclidean,
            'man' for Manhattan, 'min' for Minkowski.
        p: Order of the Minkowski distance; only used when distances_way is
            'min'. When left as None the value is requested interactively
            through input(), as before.

    Returns:
        The label that occurs most often among the k nearest samples. On a
        tie, the label whose first vote came from the closer neighbour wins.

    Raises:
        ValueError: If the dataset is empty, k < 1, p < 1, or distances_way
            is not one of the supported selectors.
    """
    dataset = np.asarray(dataset)
    sample_count = dataset.shape[0]
    if sample_count == 0:
        raise ValueError("dataset must contain at least one sample")
    if k < 1:
        raise ValueError("k must be a positive integer")

    # Broadcasting subtracts inx from every row; no need to tile it first.
    abs_diff = np.abs(np.asarray(inx) - dataset)

    if distances_way == 'o':  # Euclidean distance
        distance = (abs_diff ** 2).sum(axis=1) ** 0.5
    elif distances_way == 'man':  # Manhattan distance
        distance = abs_diff.sum(axis=1)
    elif distances_way == 'min':  # Minkowski distance
        if p is None:
            p = int(input('输入p值:'))
        if p < 1:
            raise ValueError("Minkowski order p must be >= 1")
        # The differences must be raised to p (not squared) before the root.
        distance = (abs_diff ** p).sum(axis=1) ** (1 / p)
    else:
        raise ValueError(
            "distances_way must be 'o', 'man' or 'min', got %r" % (distances_way,)
        )

    # Indices of the training samples ordered from nearest to farthest.
    nearest = distance.argsort()[:min(k, sample_count)]

    # Tally the labels of the nearest samples; dict insertion order follows
    # increasing distance, so max() resolves ties towards the closer label.
    votes = {}
    for idx in nearest:
        label = labels[idx]
        votes[label] = votes.get(label, 0) + 1
    return max(votes, key=votes.get)


In [39]:
# Flatten a 32x32 text image into a 1x1024 row vector.
def data_read(path):
    """Read a 32x32 grid of digit characters into a (1, 1024) float array.

    The first 32 characters of each of the first 32 lines are converted with
    int() and stored row after row, so the character at line i, column j
    lands at index 32 * i + j.

    Args:
        path: Path of the text file to read.

    Returns:
        A numpy array of shape (1, 1024).

    Raises:
        ValueError: If one of the first 32 lines has fewer than 32 characters
            or contains a character that int() cannot convert.
    """
    data_use = np.zeros((1, 1024))
    # The context manager guarantees the handle is closed, even on error.
    with open(path) as data:
        for i in range(32):
            data_line = data.readline()
            if len(data_line) < 32:
                raise ValueError(
                    "%s: line %d has fewer than 32 characters" % (path, i + 1)
                )
            data_use[0, 32 * i:32 * (i + 1)] = [int(ch) for ch in data_line[:32]]
    return data_use

#将32*32的text文件化成32*32的矩阵
# def data_read(path):
#     data = open(path)
#     data_use = np.zeros([32,32])
#     data_narry = np.array(data)
#     for j in range(len(a)):
#         for i in range(len(a)):
#             data_use[i][j] = a[i][j]
#     return data_use

In [44]:
# Step 2: read a whole folder of digit files and collect the label of each one.
def txt_read(path):
    """Load every file in a folder as a flattened digit vector.

    The label of a file is the integer in front of the first '_' in its
    name, after everything from the first '.' onwards has been dropped.

    Args:
        path: Folder whose entries are listed with os.listdir.

    Returns:
        A tuple (file_label, data): file_label is the list of integer labels
        and data is an array of shape (number_of_files, 1024) whose row i
        holds the vector read from the i-th listed file.
    """
    names = listdir(path)  # names of all entries inside the folder
    data = np.zeros((len(names), 1024))
    file_label = []
    for row, name in enumerate(names):
        stem = name.split('.')[0]
        digit = stem.split('_')[0]
        file_label.append(int(digit))
        data[row, :] = data_read(path + '/%s' % name)
    return file_label, data

In [53]:
def handwritingClassTest(
    train_dir='D:/Github-code/python-gogogo/机器学习/分类问题/K近邻算法/手写数据文件/training_handwriting',
    test_dir='D:/Github-code/python-gogogo/机器学习/分类问题/K近邻算法/手写数据文件/test_handwriting',
    k=3,
    distances_way='o',
):
    """Evaluate the KNN classifier on the handwritten-digit files.

    Every file in train_dir becomes one training sample. Each file in
    test_dir is then classified with KNN and compared with the digit encoded
    in its file name (the integer before the first '_'). The per-file
    results, the number of errors and the error rate are printed.

    Args:
        train_dir: Folder holding the training files.
        test_dir: Folder holding the test files.
        k: Number of neighbours passed to KNN.
        distances_way: Distance selector passed to KNN.

    Returns:
        None; all results are reported through print().
    """
    # txt_read performs the same listing / label parsing / flattening that
    # was previously duplicated here for the training set.
    hwLabels, trainingMat = txt_read(train_dir)

    testFileList = listdir(test_dir)
    mTest = len(testFileList)
    errorCount = 0.0
    for fileNameStr in testFileList:
        # The true digit is the part of the file name before the first '_'.
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])

        vectorUnderTest = data_read('%s/%s' % (test_dir, fileNameStr))
        print(vectorUnderTest.shape)
        classifierResult = KNN(vectorUnderTest, trainingMat, hwLabels, k, distances_way)
        print("KNN分类结果: %d, 实际结果: %d" % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\n错误数字数量: %d" % errorCount)
    # An empty test folder would otherwise divide by zero.
    errorRate = errorCount / float(mTest) if mTest else 0.0
    print("\n错误比率: %f" % errorRate)

In [54]:
# Run the evaluation; it prints one result line per test file followed by
# the error count and the error rate.
handwritingClassTest()

(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
(1, 1024)
