<a href="https://colab.research.google.com/github/zing53/slr/blob/main/slr_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 环境

In [None]:
# @title 获取代码，安装相应的库
# Fetch the project repository and install the Python libraries it uses.
# NOTE(review): `git clone` errors if the `slr` directory already exists, so
# re-running this cell reports a clone failure; the pip line still runs.
!git clone https://github.com/zing53/slr.git
!pip install joblib scikit-learn mediapipe

In [None]:
# @title 执行一次，重启会话，否则MediaPipe报错
# Run once: terminate the kernel process so the Colab session restarts
# (per the title, MediaPipe errors otherwise — presumably because the packages
# installed above are only picked up cleanly by a fresh process).
import os
os._exit(0)

# 处理数据和训练
如直接使用提供的模型，可跳过此环节

In [None]:
# @title 获取数据集
# Clone the Sign-Language-Digits dataset into /content/slr/data/.
# NOTE(review): re-running fails at `git clone` once the dataset folder exists.
!cd /content/slr/data/; git clone https://github.com/ardamavi/Sign-Language-Digits-Dataset.git

Cloning into 'Sign-Language-Digits-Dataset'...
remote: Enumerating objects: 2095, done.
remote: Counting objects: 100% (6/6), done.
remote: Compressing objects: 100% (6/6), done.
remote: Total 2095 (delta 2), reused 0 (delta 0), pack-reused 2089 (from 1)
Receiving objects: 100% (2095/2095), 15.07 MiB | 24.22 MiB/s, done.
Resolving deltas: 100% (660/660), done.


In [None]:
# @title 提取数据集特征
# Run the repo's feature-extraction script from the repo root. Output is
# unbuffered (-u) so progress prints immediately; stderr is redirected to
# extract_log.txt, so check that file if the step appears to do nothing.
!cd /content/slr/; python3 -u feature_extraction.py 2>extract_log.txt

21点手部特征数据已成功保存至./data/hand_landmarks.csv
数据处理完成，已保存为processed_hand_landmarks.csv


In [None]:
# @title 训练knn模型
# Train the KNN model with the repo's train.py. stderr is redirected to
# train_log.txt, so errors only show up in that file.
!cd /content/slr/; python3 -u train.py 2>train_log.txt

# 运行

In [None]:
# @title 选择图片进行手势识别
import ipywidgets as widgets
import os
import cv2
from google.colab.patches import cv2_imshow
from IPython.display import display, clear_output
from PIL import Image
from slr.process_and_classify import detector, drawer_crop, classifier

PIC_DIR = "/content/slr/data/samples/"  # sample images; the picker expects names like IMG_<n>.jpg / IMG_<n>.png

# Process one picked sample image and run the prediction.
def process_image(image_path):
    """Detect the hand in *image_path*, show the annotated image and classify it.

    The annotated image produced by ``drawer_crop`` is shown scaled so its
    longer side is 200 px. When a hand is detected, the KNN prediction is
    printed and compared with the ground-truth digit encoded in the file name
    of *image_path* (``IMG_<digit>.<ext>``). Files without such a name only
    get the prediction.

    Args:
        image_path: Path of the image to recognize.
    """
    image, results = detector(image_path)
    # Draw the landmarks; drawer_crop returns the path of the annotated image.
    annotated_image_path = drawer_crop(image, results)
    annotated_image = cv2.imread(annotated_image_path)
    if annotated_image is None:
        # cv2.imread signals failure by returning None instead of raising.
        print(f"无法读取标注后的图片: {annotated_image_path}")
    else:
        # Scale so the longer side is max_size px; clamp to 1 px so extreme
        # aspect ratios never produce a zero-sized target for cv2.resize.
        max_size = 200
        h, w = annotated_image.shape[:2]
        scale = min(max_size / h, max_size / w)
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        cv2_imshow(cv2.resize(annotated_image, new_size))

    if not results.hand_landmarks:
        print("没有检测到手势")
        return

    prediction, max_probs, time_used = classifier(results)
    print(f"预测的手势类别：{prediction}，预测置信度：{max_probs:.4f},预测耗时：{time_used:.6f} 秒")

    # The ground truth must come from the image actually processed (the
    # parameter), not from the picker's global selection, e.g. IMG_7.png -> 7.
    try:
        true_label = int(os.path.basename(image_path).split('_')[1].split('.')[0])
    except (IndexError, ValueError):
        return  # no digit in the file name: nothing to compare against
    # NOTE(review): assumes `prediction` compares equal to an int label; if the
    # classifier returns strings, every prediction would be reported as wrong.
    if prediction == true_label:
        print("预测正确！")
    else:
        print(f"预测错误，真实类别是 {true_label}")

# Shared state for the sample picker: the currently selected image path, the
# area where recognition results are printed, and the "recognize" button.
selected_image_path = None  # path of the picked image, None when nothing is selected
output = widgets.Output()
submit_button = widgets.Button(description='识别手势')

# Click handler for the "recognize" button.
def on_button_clicked(b):
    """Run recognition on the currently selected sample image.

    Clears the shared ``output`` area, then either processes the selected
    image or asks the user to pick one first.

    Args:
        b: The clicked button widget (unused).
    """
    with output:
        clear_output()
        path = selected_image_path  # read-only access, so no `global` needed
        if not path:
            print("请先选择图片！")
            return
        print(f"选中的图片: {path}")
        process_image(path)

# Sample-image picker UI.
def sample_image_names(directory, limit=10):
    """Return up to *limit* sample image file names from *directory*.

    Only files named ``IMG_<number>.jpg`` or ``IMG_<number>.png`` are kept,
    ordered by that number. Names whose number part is not a decimal integer
    are skipped instead of raising. ``limit=None`` returns every match.
    """
    numbered = []
    for name in os.listdir(directory):
        if not (name.startswith('IMG_') and name.endswith(('.jpg', '.png'))):
            continue
        number = name.split('_')[1].split('.')[0]
        if number.isdecimal():
            numbered.append((int(number), name))
    numbered.sort()
    return [name for _, name in numbered[:limit]]


def show_images():
    """Display the sample picker: a title, a thumbnail grid and a status line.

    Resets the global ``selected_image_path`` to None, then shows up to 10
    samples from ``PIC_DIR``, each with a "选择" button. Clicking a button
    selects that image (its path goes into ``selected_image_path``). Clicking
    the already-selected button clears the selection.
    """
    global selected_image_path
    selected_image_path = None  # nothing selected yet

    max_images = 10   # the grid shows at most this many samples
    max_height = 120  # common thumbnail height in px; width keeps the aspect ratio

    # The selected button lives in a one-element list so the nested handler
    # can rebind it.
    selected_button = [None]

    # Status area for the selection text; separate from the module-level
    # `output` that receives recognition results.
    status = widgets.Output(
        layout=widgets.Layout(
            padding='10px',
            margin='10px 0',
            min_height='10px'
        )
    )

    def on_button_click(b):
        """Move the selection to button *b* (or clear it) and report it."""
        global selected_image_path

        # Restore the look of the previously selected button, if any.
        previous = selected_button[0]
        if previous is not None:
            previous.button_style = ''
            previous.description = '选择'

        if previous is b:
            # Clicking the selected button again deselects it.
            selected_button[0] = None
            selected_image_path = None
        else:
            selected_button[0] = b
            b.button_style = 'success'
            b.description = '已选择'
            selected_image_path = b.img_path

        with status:
            clear_output()
            if selected_image_path:
                image_num = os.path.basename(selected_image_path).split('_')[1].split('.')[0]
                print(f"已选择图片{image_num}")
            else:
                print("未选择任何图片")

    # Only the images that will actually be shown are opened.
    image_widgets = []
    for name in sample_image_names(PIC_DIR, limit=max_images):
        image_path = os.path.join(PIC_DIR, name)
        image_number = name.split('_')[1].split('.')[0]

        # Only the size is needed. The context manager closes the file handle
        # that PIL keeps open after its lazy open.
        with Image.open(image_path) as img:
            aspect_ratio = img.width / img.height
        new_width = int(max_height * aspect_ratio)

        image_widget = widgets.Image.from_file(
            image_path,
            width=new_width,
            height=max_height
        )
        # NOTE(review): `text_align` (and `border_radius` below) do not appear
        # to be ipywidgets Layout traits, so they may be ignored — verify.
        label = widgets.Label(f"数字{image_number}", layout=widgets.Layout(text_align='center'))

        select_btn = widgets.Button(
            description='选择',
            button_style='',
            layout=widgets.Layout(width='80px')
        )
        select_btn.img_path = image_path  # path travels with the button for the handler
        select_btn.on_click(on_button_click)

        image_widgets.append(widgets.VBox(
            [image_widget, label, select_btn],
            layout=widgets.Layout(
                align_items='center',
                border='1px solid #ddd',
                padding='10px',
                margin='5px',
                border_radius='8px',
                width='auto',  # let the card size itself to its content
                min_width=f'{max(new_width, 100)}px'  # at least as wide as the thumbnail
            )
        ))

    grid = widgets.GridBox(
        image_widgets,
        layout=widgets.Layout(
            grid_template_columns="repeat(5, auto)",
            grid_gap="15px",
            padding='10px'
        )
    )

    title = widgets.HTML(
        value="<h2 style='text-align:center; color:#4682B4;'>基于MediaPipe和KNN的手势识别系统</h2>"
    )

    display(widgets.VBox([title, grid, status]))
    # Initial prompt in the status area.
    with status:
        print("请选择一张图片")

# Wire the recognize button first, then render the picker followed by the
# button and the results area.
submit_button.on_click(on_button_clicked)
show_images()
display(submit_button, output)

In [None]:
# @title 上传图片进行手势识别
import cv2
import ipywidgets as widgets
from IPython.display import display, clear_output
from google.colab import files
from google.colab.patches import cv2_imshow
from slr.process_and_classify import detector, drawer_crop, classifier

# State shared by the upload / recognize handlers: the raw bytes and on-disk
# file name of the last upload, plus the area their messages are printed to.
uploaded_image = IMAGE_FILE = None
output_area = widgets.Output()

def display_resized(image_path, size):
    """Show the image at *image_path* in the cell, scaled to fit a *size* box.

    The aspect ratio is preserved and the longer side becomes *size* px
    (smaller images are enlarged as well).

    Args:
        image_path: Path of the image file to show.
        size: Target length, in px, of the longer side.

    Raises:
        OSError: If OpenCV cannot read *image_path*. ``cv2.imread`` returns
            None instead of raising, which used to surface as an obscure
            AttributeError.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise OSError(f"cannot read image: {image_path}")
    h, w = image.shape[:2]
    scale = min(size / h, size / w)
    # Clamp to 1 px so extreme aspect ratios never yield a zero-sized target,
    # which cv2.resize rejects.
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    cv2_imshow(cv2.resize(image, new_size))


# Upload flow.
def upload_files():
    """Reset the upload UI, show an example, and let the user upload an image.

    The first uploaded file is written to the current directory under its
    upload name. Its name and bytes are stored in the globals ``IMAGE_FILE``
    and ``uploaded_image``, a thumbnail is shown, and the recognize button
    appears. Failures to show the example or the upload are reported in the
    cell instead of aborting the handler.
    """
    global uploaded_image, IMAGE_FILE
    # Clear the previous cell output and the shared output area.
    clear_output(wait=True)
    output_area.clear_output()

    display(upload_button)

    print("请提供单手手势图片")
    print("样例：")
    example = '/content/slr/data/samples/IMG_7.png'
    try:
        display_resized(example, 100)
    except (AttributeError, OSError) as exc:
        # The example is only a hint; a missing sample must not block uploads.
        print(f"(无法显示样例图片: {exc})")

    display(output_area)

    with output_area:
        uploaded = files.upload()
        if not uploaded:
            print('未上传任何文件')
            return

        # Only the first file is used when several are uploaded.
        IMAGE_FILE = next(iter(uploaded))
        uploaded_image = uploaded[IMAGE_FILE]

        # Write the bytes to disk so the detector can read them by path.
        with open(IMAGE_FILE, 'wb') as f:
            f.write(uploaded_image)

        print('已上传图片:')
        try:
            display_resized(IMAGE_FILE, 200)
        except (AttributeError, OSError) as exc:
            # Not a readable image: don't offer recognition on it.
            print(f"无法读取上传的图片: {exc}")
            return
        display(recognize_button)  # show the recognize button


# Recognition for uploaded images.
def process_handler(_):
    """Recognize-button callback: classify the last uploaded image.

    Output goes to ``output_area``. Any error is reported there with a
    traceback instead of propagating out of the widget callback.
    """
    with output_area:
        if IMAGE_FILE is None:  # read-only access, so no `global` needed
            print("请先上传图片")
            return

        try:
            process_uploaded_image(IMAGE_FILE)
        except Exception as e:  # UI boundary: report instead of crashing
            import traceback
            print(f"处理图像时出错: {e}")
            print("详细错误信息:")
            traceback.print_exc()


# Processing function for uploads, with its own error handling.
def process_uploaded_image(image_path):
    """Detect the hand in *image_path*, show the annotated image and classify it.

    Deliberately not named ``process_image``. Notebook cells share one
    namespace, so reusing that name would silently replace the sample-picker
    cell's ``process_image`` (which also checks the label encoded in the file
    name) once this cell has run.

    Errors are printed with a traceback rather than raised.

    Args:
        image_path: Path of the uploaded image on disk.
    """
    try:
        image, results = detector(image_path)

        # Draw the landmarks; drawer_crop returns the annotated image's path.
        annotated_image_path = drawer_crop(image, results)
        display_resized(annotated_image_path, 200)

        if not results.hand_landmarks:
            print("没有检测到手势")
            return

        # KNN classification.
        prediction, max_probs, time_used = classifier(results)

        print(f"预测的手势类别：{prediction}，预测置信度：{max_probs:.4f}，预测耗时：{time_used:.6f} 秒")

    except Exception as e:
        print(f"处理函数内部出错: {e}")
        import traceback
        traceback.print_exc()

# Buttons for the upload flow.
upload_button = widgets.Button(description='上传图片', button_style='primary', icon='upload')
recognize_button = widgets.Button(description='识别手势', button_style='success', icon='check')

# Upload (re)starts the upload flow; recognize classifies the uploaded image.
upload_button.on_click(lambda _: upload_files())
recognize_button.on_click(process_handler)

# Initial UI: the upload button followed by the shared output area.
display(upload_button, output_area)