# 拇指项目手势识别（拇指向上，拇指向下）

- 数据训练与测试的入门交互式案例：通过 PC 或笔记本电脑的摄像头采集数据，并通过交互界面进行过程控制和参数调整。
- 技术：采用 PyTorch AI 框架、torchvision 数据集库、ipywidgets（Jupyter 交互界面库）。
- 任务：定义任务与分类，共分 2 类——thumbs_up：大拇指向上；thumbs_down：大拇指向下。

In [1]:
import torchvision.transforms as transforms
from dataset import ImageClassificationDataset

# Task name; each dataset is created under the name "<TASK>_<dataset name>".
TASK = 'thumbs'
# Category labels handed to every dataset.
CATEGORIES = ['thumbs_up', 'thumbs_down']
# Dataset names (A is used for training, B for testing).
DATASETS = ['A', 'B']

# Transform pipeline given to each dataset: colour jitter, resize to 224x224,
# conversion to a tensor, then per-channel normalisation.
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# One ImageClassificationDataset per entry of DATASETS, keyed by dataset name.
# Extend DATASETS to collect additional datasets.
datasets = {}
for name in DATASETS:
    datasets[name] = ImageClassificationDataset(TASK + '_' + name, CATEGORIES, TRANSFORMS)

# Confirmation message: reports the task name and its categories.
print("任务{},分类{}已定义！".format(TASK, CATEGORIES))

任务thumbs,分类['thumbs_up', 'thumbs_down']已定义！


### 步骤 1.数据采集
- 借助 IPython 小组件（ipywidgets），通过交互界面使用摄像头为各个类别采集图像。
- A 数据集用于训练，B 数据集用于测试。


In [2]:
import ipywidgets
import traitlets
from IPython.display import display
import threading
import cv2
from camera import Camera,bgr8_to_jpeg

# Camera object from the project-local `camera` module; its `value` trait is
# linked to the preview widget below.
camera = Camera()

# Dataset currently targeted by the collection UI; starts as the first entry
# of DATASETS.
dataset = datasets[DATASETS[0]]

# Preview image: every change of camera.value is converted by bgr8_to_jpeg and
# pushed into the image widget (one-directional link).
camera_widget = ipywidgets.Image(format='jpg', width=224, height=224)
traitlets.dlink(
    source=(camera, 'value'),
    target=(camera_widget, 'value'),
    transform=bgr8_to_jpeg,
)

# Dropdown selecting which dataset receives new images.
dataset_widget = ipywidgets.Dropdown(description='数据集', options=DATASETS)
# Dropdown selecting the category under which images are saved.
category_widget = ipywidgets.Dropdown(description='分类', options=dataset.categories)
# Button that stores the current camera frame.
save_widget = ipywidgets.Button(description='保存')
# Counter initialised with the count reported by the dataset for the selected
# category.
count_widget = ipywidgets.IntText(
    description='累计',
    value=dataset.get_count(category_widget.value),
)

#
def set_dataset(change):
    """Switch the active dataset when the dataset dropdown changes.

    Args:
        change: Change notification from ``dataset_widget``; ``change['new']``
            is the newly selected dataset name.
    """
    global dataset
    selected_name = change['new']
    dataset = datasets[selected_name]
    # Refresh the counter for the currently selected category in the new dataset.
    current_category = category_widget.value
    count_widget.value = dataset.get_count(current_category)


dataset_widget.observe(set_dataset, names='value')

def update_counts(change):
    """Refresh the counter when the category dropdown changes.

    Args:
        change: Change notification from ``category_widget``; ``change['new']``
            is the newly selected category.
    """
    selected_category = change['new']
    count_widget.value = dataset.get_count(selected_category)


category_widget.observe(update_counts, names='value')

def save(c):
    """Store the current camera frame under the selected category and refresh the counter.

    Args:
        c: The clicked button, as passed by ``on_click``; unused.
    """
    chosen_category = category_widget.value
    dataset.save_entry(camera.value, chosen_category)
    # Re-read the selection, as the original did, before querying the new count.
    count_widget.value = dataset.get_count(category_widget.value)


save_widget.on_click(save)

# Data-collection panel: camera preview on top, followed by the dataset and
# category selectors, the counter and the save button.
data_collection_widget = ipywidgets.VBox(children=[
    ipywidgets.HBox(children=[camera_widget]),
    dataset_widget,
    category_widget,
    count_widget,
    save_widget,
])

display(data_collection_widget)

VBox(children=(HBox(children=(Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x…

### 步骤 2.定义神经网络

In [3]:
import torch
import torchvision


# Backbone: ResNet-18 created with torchvision's pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=`; kept as-is for compatibility with the installed version.
model = torchvision.models.resnet18(pretrained=True)
# Replace the classification head with one output per category. The input
# width is read from the existing head instead of hard-coding 512.
model.fc = torch.nn.Linear(model.fc.in_features, len(dataset.categories))


model_save_button = ipywidgets.Button(description='保存模型')
# Label fixed: the original read '装载模式' ("load mode"), a typo for
# '装载模型' ("load model").
model_load_button = ipywidgets.Button(description='装载模型')
# Path of the checkpoint file used by the load and save buttons.
model_path_widget = ipywidgets.Text(description='目录路径', value='my_model.pth')
def load_model(c):
    """Load model weights from the path entered in ``model_path_widget``.

    Args:
        c: The clicked button, as passed by ``on_click``; unused.
    """
    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted source.
    # map_location='cpu' lets a checkpoint saved from CUDA tensors load on a
    # machine without a GPU; load_state_dict then copies the values into the
    # model's existing parameters.
    state_dict = torch.load(model_path_widget.value, map_location='cpu')
    model.load_state_dict(state_dict)


model_load_button.on_click(load_model)
def save_model(c):
    """Write the model's state dict to the path entered in ``model_path_widget``.

    Args:
        c: The clicked button, as passed by ``on_click``; unused.
    """
    weights = model.state_dict()
    destination = model_path_widget.value
    torch.save(weights, destination)


model_save_button.on_click(save_model)
# Model panel: checkpoint path on top, load/save buttons side by side below.
model_widget = ipywidgets.VBox(children=[
    model_path_widget,
    ipywidgets.HBox(children=[model_load_button, model_save_button]),
])

display(model_widget)

VBox(children=(Text(value='my_model.pth', description='目录路径'), HBox(children=(Button(description='装载模式', style…

### 实时执行

In [4]:
import threading
import time
from utils import preprocess
import torch.nn.functional as F

# Toggle that switches live inference between 'stop' and 'live'; starts stopped.
state_widget = ipywidgets.ToggleButtons(options=['stop', 'live'], description='状态', value='stop')
# Text box that shows the predicted category name.
prediction_widget = ipywidgets.Text(description='预测')
# One vertical slider (range 0.0-1.0) per category, in dataset.categories order.
# NOTE: the loop variable `score_widget` stays bound at module level after the
# loop (to the last slider) and is read by name further down in this cell, so
# this loop must not be turned into a comprehension.
score_widgets = []
for category in dataset.categories:
    score_widget = ipywidgets.FloatSlider(min=0.0, max=1.0, description=category, orientation='vertical')
    score_widgets.append(score_widget)

def live(state_widget, model, camera, prediction_widget, score_widget):
    """Classify camera frames in a loop while the state toggle reads 'live'.

    Args:
        state_widget: Toggle whose ``value`` is polled; the loop ends as soon
            as it is no longer ``'live'``.
        model: Network applied to each preprocessed frame.
        camera: Object whose ``value`` attribute holds the latest frame.
        prediction_widget: Text widget that receives the top category name.
        score_widget: List or tuple of sliders, one per model output, in
            output order. For backward compatibility any other value (the
            original caller passed a single slider, which was ignored) falls
            back to the module-level ``score_widgets`` list.
    """
    if isinstance(score_widget, (list, tuple)):
        sliders = score_widget
    else:
        sliders = score_widgets

    while state_widget.value == 'live':
        image = camera.value
        preprocessed = preprocess(image)
        # Inference only: no_grad avoids building an autograd graph per frame.
        with torch.no_grad():
            output = model(preprocessed)
            output = F.softmax(output, dim=1).cpu().numpy().flatten()
        category_index = int(output.argmax())
        prediction_widget.value = dataset.categories[category_index]
        for slider, score in zip(sliders, output):
            slider.value = float(score)


def start_live(change):
    """Start the background inference thread when the toggle switches to 'live'.

    Args:
        change: Change notification from ``state_widget``; ``change['new']``
            is the newly selected state.
    """
    if change['new'] == 'live':
        # Pass the full slider list; the original passed only the last slider
        # created, which `live` never used.
        execute_thread = threading.Thread(
            target=live,
            args=(state_widget, model, camera, prediction_widget, score_widgets),
        )
        execute_thread.start()


state_widget.observe(start_live, names='value')

# Live-inference panel: score sliders in a row, then the predicted label and
# the stop/live toggle.
live_execution_widget = ipywidgets.VBox(children=[
    ipywidgets.HBox(children=score_widgets),
    prediction_widget,
    state_widget,
])

display(live_execution_widget)

VBox(children=(HBox(children=(FloatSlider(value=0.0, description='thumbs_up', max=1.0, orientation='vertical')…

### 训练

In [5]:
# Number of samples per DataLoader batch.
BATCH_SIZE = 8

# Adam over all model parameters, using the optimizer's default settings.
# Alternative: torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
optimizer = torch.optim.Adam(params=model.parameters())

# Remaining number of training epochs; decremented after each epoch.
epochs_widget = ipywidgets.IntText(value=1, description='次数')
# Buttons that trigger an evaluation pass / a training run.
eval_button = ipywidgets.Button(description='评估')
train_button = ipywidgets.Button(description='训练')
# Running loss and accuracy read-outs, updated after every batch.
loss_widget = ipywidgets.FloatText(description='损失')
accuracy_widget = ipywidgets.FloatText(description='准确率')
# Fraction of the dataset processed in the current epoch.
progress_widget = ipywidgets.FloatProgress(description='进度', min=0.0, max=1.0)

def train_eval(is_training):
    """Train on, or evaluate against, the currently selected dataset.

    When training, runs one epoch at a time until ``epochs_widget`` reaches
    zero, decrementing it after each epoch. When evaluating, makes a single
    pass. Progress, mean loss and accuracy widgets are updated after every
    batch. The model is always put back in eval mode, the buttons are
    re-enabled and live inference is resumed afterwards, even when an error
    occurs.

    Args:
        is_training: True to update the weights, False to only measure loss
            and accuracy.
    """
    # `model` is rebound below, so it must be declared global; all other
    # module-level names are only read.
    global model

    import traceback

    try:
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
        )

        # Stop live inference and lock the buttons while the model is in use.
        state_widget.value = 'stop'
        train_button.disabled = True
        eval_button.disabled = True
        # Presumably gives the live-inference thread time to notice 'stop'.
        time.sleep(1)

        model = model.train() if is_training else model.eval()

        # Gradients are only tracked when training; evaluation skips autograd.
        with torch.set_grad_enabled(bool(is_training)):
            while epochs_widget.value > 0:
                seen = 0
                sum_loss = 0.0
                error_count = 0
                for images, labels in loader:
                    if is_training:
                        optimizer.zero_grad()

                    outputs = model(images)
                    loss = F.cross_entropy(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                    count = labels.numel()
                    error_count += int((outputs.argmax(1) != labels).sum())
                    seen += count
                    # cross_entropy returns the batch mean; weight it by the
                    # batch size so sum_loss / seen is the per-sample mean
                    # (the original divided a sum of batch means by the
                    # sample count, under-reporting the loss).
                    sum_loss += float(loss) * count
                    progress_widget.value = seen / len(dataset)
                    loss_widget.value = sum_loss / seen
                    accuracy_widget.value = 1.0 - error_count / seen

                if not is_training:
                    # Evaluation is a single pass.
                    break
                epochs_widget.value = epochs_widget.value - 1
    except Exception:
        # The original `except e:` raised NameError on any failure, which
        # skipped the cleanup below and left the buttons disabled. Keep the
        # best-effort behaviour but report what went wrong.
        traceback.print_exc()
    finally:
        model = model.eval()

        train_button.disabled = False
        eval_button.disabled = False
        state_widget.value = 'live'
    
# Wire the buttons: the clicked button passed by on_click is ignored.
train_button.on_click(lambda _button: train_eval(is_training=True))
eval_button.on_click(lambda _button: train_eval(is_training=False))

# Training panel: epoch count, progress bar, loss and accuracy read-outs, then
# the train/evaluate buttons side by side.
train_eval_widget = ipywidgets.VBox(children=[
    epochs_widget,
    progress_widget,
    loss_widget,
    accuracy_widget,
    ipywidgets.HBox(children=[train_button, eval_button]),
])

display(train_eval_widget)

VBox(children=(IntText(value=1, description='次数'), FloatProgress(value=0.0, description='进度', max=1.0), FloatT…

In [6]:
# Combined dashboard: data collection next to live inference on the first row,
# with the training panel and the model load/save panel underneath.
all_widget = ipywidgets.VBox(children=[
    ipywidgets.HBox(children=[data_collection_widget, live_execution_widget]),
    train_eval_widget,
    model_widget,
])

display(all_widget)

VBox(children=(HBox(children=(VBox(children=(HBox(children=(Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01…