In [None]:
!pip install face-recognition
!pip install scikit-learn

In [None]:
import math
from sklearn import neighbors
import os
import os.path
import pickle
from PIL import Image, ImageDraw
import face_recognition
from face_recognition.face_recognition_cli import image_files_in_folder

In [None]:
# Image file extensions (lower-case, without the leading dot) that predict()
# accepts; any other extension is rejected as an invalid image path.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}


def train(train_dir, model_save_path=None, n_neighbors=None, knn_algo='ball_tree', verbose=False):
    """
    Train a k-nearest-neighbours classifier for face recognition.

    Expected layout of ``train_dir``:

        <train_dir>/
        ├── <person1>/
        │   ├── <somename1>.jpeg
        │   ├── <somename2>.jpeg
        │   ├── ...
        ├── <person2>/
        │   ├── <somename1>.jpeg
        │   └── <somename2>.jpeg
        └── ...

    :param train_dir: directory containing one sub-directory per known person;
        the sub-directory name is used as that person's label.
    :param model_save_path: (optional) path the fitted classifier is pickled to.
    :param n_neighbors: (optional) number of neighbours used for voting. When
        omitted it is set to round(sqrt(number of usable training faces)).
    :param knn_algo: (optional) value passed as ``algorithm`` to
        KNeighborsClassifier. Defaults to ball_tree.
    :param verbose: when true, report skipped images and the auto-selected
        neighbour count.
    :return: the KNeighborsClassifier fitted on the collected face encodings.
    :raises ValueError: if no training image with exactly one face was found.
    """

    X = []  # face encodings
    y = []  # label (person directory name) for each encoding
    print("正在加载KNN训练图像")

    # List the directory once and keep only the per-person sub-directories so
    # the progress counter reflects the work that is actually done.
    person_dirs = [
        entry for entry in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, entry))
    ]
    total = len(person_dirs)

    for index, class_dir in enumerate(person_dirs, start=1):
        print(f"{index}/{total}")
        for img_path in image_files_in_folder(os.path.join(train_dir, class_dir)):
            image = face_recognition.load_image_file(img_path)
            face_bounding_boxes = face_recognition.face_locations(image)

            # An image is only usable when it shows exactly one face;
            # otherwise the label would be ambiguous (or there is nothing to encode).
            if len(face_bounding_boxes) != 1:
                if verbose:
                    print("Image {} not suitable for training: {}".format(img_path, "Didn't find a face" if len(face_bounding_boxes) < 1 else "Found more than one face"))
                continue

            X.append(face_recognition.face_encodings(image, known_face_locations=face_bounding_boxes)[0])
            y.append(class_dir)

    # Fail early with a clear message instead of letting the classifier
    # reject an empty data set (or n_neighbors == 0) with an opaque error.
    if not X:
        raise ValueError(
            "No usable training images found in {!r}: each image must contain exactly one face".format(train_dir)
        )

    # Pick the neighbour count from the training-set size when not specified.
    if n_neighbors is None:
        n_neighbors = int(round(math.sqrt(len(X))))
        if verbose:
            print("邻居数量为:", n_neighbors)

    # weights='distance' makes closer neighbours count more in the vote.
    knn_clf = neighbors.KNeighborsClassifier(n_neighbors=n_neighbors, algorithm=knn_algo, weights='distance')
    print("正在训练KNN分类器")
    knn_clf.fit(X, y)

    # Persist the fitted classifier when a destination was given.
    if model_save_path is not None:
        with open(model_save_path, 'wb') as f:
            pickle.dump(knn_clf, f)

    return knn_clf


def predict(X_img_path, knn_clf=None, model_path=None, distance_threshold=0.6):
    """
    Recognise the faces in an image with a trained KNN classifier.

    :param X_img_path: path of the image to analyse.
    :param knn_clf: (optional) a fitted KNN classifier object. If omitted,
        model_path must be given.
    :param model_path: (optional) path to a pickled KNN classifier. If omitted,
        knn_clf must be given.
    :param distance_threshold: (optional) maximum distance to the nearest
        training face for a match. Larger values increase the chance of
        labelling an unknown person as a known one.
    :return: list of (name, bounding box) tuples, one per detected face;
        faces whose nearest training face is farther than the threshold are
        named 'unknown'. Empty when no face is detected.
    :raises ValueError: if the path is not an existing file with an allowed
        image extension, or if neither knn_clf nor model_path is supplied.
    """
    # Compare the extension case-insensitively so "photo.JPG" is accepted too.
    extension = os.path.splitext(X_img_path)[1][1:].lower()
    if not os.path.isfile(X_img_path) or extension not in ALLOWED_EXTENSIONS:
        raise ValueError("Invalid image path: {}".format(X_img_path))

    if knn_clf is None and model_path is None:
        raise ValueError("Must supply knn classifier either thourgh knn_clf or model_path")

    # No classifier object was passed in: load the stored one.
    if knn_clf is None:
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # model files that come from a trusted source.
        with open(model_path, 'rb') as f:
            knn_clf = pickle.load(f)

    X_img = face_recognition.load_image_file(X_img_path)
    X_face_locations = face_recognition.face_locations(X_img)

    # Nothing to classify when no face was detected.
    if len(X_face_locations) == 0:
        return []

    # Encode every detected face, reusing the locations found above.
    faces_encodings = face_recognition.face_encodings(X_img, known_face_locations=X_face_locations)

    # Distance from each face to its single nearest training face decides
    # whether the classifier's label is trusted.
    nearest_distances, _ = knn_clf.kneighbors(faces_encodings, n_neighbors=1)
    predicted_names = knn_clf.predict(faces_encodings)

    return [
        (name, loc) if distances[0] <= distance_threshold else ("unknown", loc)
        for name, loc, distances in zip(predicted_names, X_face_locations, nearest_distances)
    ]


def _measure_text(draw, text):
    """Return (width, height) of *text* for the given ImageDraw object."""
    # ImageDraw.textsize() was removed in Pillow 10; keep using it where it
    # still exists so older installations behave exactly as before.
    if hasattr(draw, "textsize"):
        return draw.textsize(text)
    # Measured from the origin, the right/bottom edges of the bounding box
    # give the extent of the text in the same way textsize() reported it.
    _, _, right, bottom = draw.textbbox((0, 0), text)
    return right, bottom


def show_prediction_labels_on_image(img_path, predictions):
    """
    Display an image with a labelled box drawn around every recognised face.

    :param img_path: path of the image that was analysed.
    :param predictions: result of predict(): an iterable of
        (name, (top, right, bottom, left)) tuples.
    :return: None. The annotated image is opened in the default viewer.
    """
    # Open inside a context manager so the source file handle is released;
    # convert() yields an independent in-memory copy to draw on.
    with Image.open(img_path) as source_image:
        pil_image = source_image.convert("RGB")
    draw = ImageDraw.Draw(pil_image)

    for name, (top, right, bottom, left) in predictions:
        # Draw a box around the face using the Pillow module
        draw.rectangle(((left, top), (right, bottom)), outline=(0, 0, 255))

        # There's a bug in Pillow where it blows up with non-UTF-8 text
        # when using the default bitmap font
        name = name.encode("UTF-8")

        # Draw a filled label strip with the name at the bottom of the box
        _, text_height = _measure_text(draw, name)
        draw.rectangle(((left, bottom - text_height - 10), (right, bottom)), fill=(0, 0, 255), outline=(0, 0, 255))
        draw.text((left + 6, bottom - text_height - 5), name, fill=(255, 255, 255, 255))

    # Remove the drawing library from memory as per the Pillow docs
    del draw

    # Display the resulting image
    pil_image.show()


In [None]:
if __name__ == "__main__":
    # Train on one data set, then report accuracy and skip rate on a test set
    # for a range of distance thresholds.
    train_addr = r'/kaggle/input/knn-cas-tarin5-test3-bsd/KNN_CSA_train_5_bsd'
    test_addr = r'/kaggle/input/knn-cas-tarin5-test3-bsd/KNN_CSA_test_3_bsd'

    # STEP 1: train the classifier (also saved to disk for later reuse)
    print("开始训练KNN分类器...")
    classifier = train(train_addr, model_save_path="knn-cas-train5-test3-bsd.clf", n_neighbors=5)
    print("KNN分类器的训练已完成!")

    # STEP 2: evaluate the classifier on the test images for each threshold
    kk_l = [0.50, 0.52, 0.54, 0.56, 0.58, 0.60, 0.62, 0.64, 0.66]
    for kk in kk_l:
        correct = 0      # faces labelled with the right person
        skipped = 0      # faces reported as "unknown"
        recognised = 0   # faces that received a known label
        # Each sub-directory of test_addr is named after the person it shows.
        for image_name in os.listdir(test_addr):
            person_dir = os.path.join(test_addr, image_name)
            if not os.path.isdir(person_dir):
                # Stray files in the test root carry no label; ignore them.
                continue
            for image_file in os.listdir(person_dir):
                full_file_path = os.path.join(person_dir, image_file)
                # Reuse the in-memory classifier instead of unpickling the
                # saved model again for every single image.
                predictions = predict(full_file_path, knn_clf=classifier, distance_threshold=kk)

                # Counters are per detected face, not per image file.
                for name, _location in predictions:
                    if name == "unknown":
                        skipped += 1
                    else:
                        recognised += 1
                        if name == image_name:
                            correct += 1

        total = recognised + skipped
        # Guard the ratios: a strict threshold can leave nothing recognised,
        # and an empty test set leaves nothing counted at all.
        accuracy = 100.0 * correct / recognised if recognised else 0.0
        skip_rate = 100.0 * skipped / total if total else 0.0
        print(f"阈值为:{kk} | 总共识别了{total}张图片 | 正确识别{correct}张 | 正确率为:{accuracy}% | 跳过了{skipped}张 | 跳过率为:{skip_rate}%")
