# 使用感知器识别所有数字

## 加载数据

In [1]:
import pylab
import numpy as np
import pickle
from ipywidgets import interact, interactive, fixed
import ipywidgets as widgets
import os
import gzip

In [2]:
np.random.seed(1)
import random

In [3]:
# 解压并加载数据集
with gzip.open('../data/mnist.pkl.gz', 'rb') as mnist_pickle:
    u = pickle._Unpickler(mnist_pickle)
    u.encoding = 'latin1'
    MNIST = u.load()

In [4]:
# 所有灰度数据归一化
features = MNIST['Train']['Features'].astype(np.float32) / 256.0
# 所有的label
labels = MNIST['Train']['Labels']

准备10组权重值，分别代表10个数字。
10组权重值的训练方式都是，如果是对应的数字则视为 pos 非对应的数字视为 neg。

In [5]:
weightsList = []

定义一个方法，可以将训练数据分类。

In [6]:
# 设置一个数字为 pos 其他所有数据都是 neg
def set_mnist_pos_neg(positive_label):
    # pos 和 neg 数据的序号
    positive_indices = [i for i, j in enumerate(MNIST['Train']['Labels']) 
                          if j == positive_label]
    negative_indices = [i for i, j in enumerate(MNIST['Train']['Labels']) 
                          if j != positive_label]
    # pos 和 neg 数据的图像
    positive_images = MNIST['Train']['Features'][positive_indices]
    negative_images = MNIST['Train']['Features'][negative_indices]
    
    return positive_images, negative_images

定义训练方法

In [7]:
# 训练方法
def train_graph(positive_examples, negative_examples, num_iterations = 100):
    # 数据维度, 此例是 28 * 28 = 784
    num_dims = positive_examples.shape[1]
    # 权重值 全0数组 shape = (784, 1)
    weights = np.zeros((num_dims, 1)) # initialize weights

    for i in range(num_iterations):
        pos = random.choice(positive_examples)
        neg = random.choice(negative_examples)

        z = np.dot(pos, weights)   
        if z < 0:
            weights = weights + pos.reshape(weights.shape)

        z  = np.dot(neg, weights)
        if z >= 0:
            weights = weights - neg.reshape(weights.shape)

    return weights

In [8]:
# 将 10 个数字各训练1000次
for x in range(10):
    _pos, _neg = set_mnist_pos_neg(x)
    weightsList.append(train_graph(_pos, _neg, num_iterations = 10000))

In [9]:
# 制作一个使用训练完的weight, 根据输入的图像来输出数字的函数
def img2Num(img):
    for x in range(10):
        if np.dot(img, weightsList[x]) > 0:
            return x
    return 0

In [10]:
# 进行若干次测试并统计正确率
def test(count):
    _len = len(features)
    correct = 0.0
    for x in range(count):
        i = random.randint(0, _len - 1)
        # 随机选取一副图像
        img = features[i]
        label = labels[i]
        if img2Num(img) == label:
            correct += 1
    return correct / count

查看训练结果

In [11]:
print(test(1000))

0.747
