# サポート・ベクトル・マシン

In [1]:
import os
import struct
import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn import svm

In [2]:
# Locations of the MNIST IDX files (they must be downloaded beforehand).
train_image_file, train_label_file = (
    'mnist/train-images-idx3-ubyte',
    'mnist/train-labels-idx1-ubyte',
)
test_image_file, test_label_file = (
    'mnist/t10k-images-idx3-ubyte',
    'mnist/t10k-labels-idx1-ubyte',
)

In [3]:
def load_images(filename):
    """Read an MNIST IDX3 image file into a normalized float32 array.

    Args:
        filename: Path to an uncompressed MNIST image file
            (e.g. ``train-images-idx3-ubyte``).

    Returns:
        A ``numpy.ndarray`` of dtype ``float32`` and shape
        ``(n_images, height, width, 1)`` whose values are scaled to [0, 1].

    Raises:
        ValueError: If the magic number is not 2051, or if the file holds
            fewer pixel bytes than its header announces.
        struct.error: If the file ends before the header is complete.
    """
    # The context manager closes the file even when validation raises.
    with open(filename, 'rb') as fp:
        # Magic number: a big-endian int32 that must be 2051 for image files.
        magic = struct.unpack('>i', fp.read(4))[0]
        if magic != 2051:
            raise ValueError('Invalid MNIST file!')

        # Dimensions: image count, rows, columns (big-endian int32 each).
        n_images, height, width = struct.unpack('>iii', fp.read(4 * 3))

        # Pixel payload: one unsigned byte per pixel.
        total_pixels = n_images * height * width
        raw = fp.read(total_pixels)

    if len(raw) != total_pixels:
        raise ValueError('Truncated MNIST file!')

    # Reinterpret the bytes directly as uint8 instead of unpacking every
    # pixel into a Python int first.
    images = np.frombuffer(raw, dtype=np.uint8)
    images = images.reshape((n_images, height, width, 1))

    # Rescale the value range from [0, 255] to [0, 1]; astype() copies, so
    # the returned array is writable and independent of the read buffer.
    return images.astype('float32') / 255.0

In [4]:
def load_labels(filename):
    """Read an MNIST IDX1 label file into an int32 array.

    Args:
        filename: Path to an uncompressed MNIST label file
            (e.g. ``train-labels-idx1-ubyte``).

    Returns:
        A one-dimensional ``numpy.ndarray`` of dtype ``int32`` holding one
        label per entry.

    Raises:
        ValueError: If the magic number is not 2049, or if the file holds
            fewer label bytes than its header announces.
        struct.error: If the file ends before the header is complete.
    """
    # The context manager closes the file even when validation raises.
    with open(filename, 'rb') as fp:
        # Magic number: a big-endian int32 that must be 2049 for label files.
        magic = struct.unpack('>i', fp.read(4))[0]
        if magic != 2049:
            raise ValueError('Invalid MNIST file!')

        # Number of labels stored in the file (big-endian int32).
        n_labels = struct.unpack('>i', fp.read(4))[0]

        # Label payload: one unsigned byte per label.
        raw = fp.read(n_labels)

    if len(raw) != n_labels:
        raise ValueError('Truncated MNIST file!')

    # Reinterpret the bytes as uint8, then widen to int32 (astype copies, so
    # the result is writable and independent of the read buffer).
    return np.frombuffer(raw, dtype=np.uint8).astype('int32')

In [5]:
def to_onehot(labels, n_classes=10):
    """Convert integer class labels to one-hot rows.

    Args:
        labels: Integer label or integer array of labels used to index the
            identity matrix.
        n_classes: Number of classes, i.e. the length of each one-hot row.
            Defaults to 10 (the ten MNIST digits).

    Returns:
        A float64 ``numpy.ndarray`` of shape ``labels.shape + (n_classes,)``
        in which row ``i`` has a single 1.0 at index ``labels[i]``.

    Raises:
        IndexError: If a label is ``>= n_classes``. Negative labels follow
            NumPy indexing and count from the end rather than raising.
    """
    # Row k of the identity matrix is the one-hot vector for class k, so
    # fancy-indexing with the labels selects one row per label.
    return np.identity(n_classes)[labels]

In [6]:
# Load the training images and their integer labels from the MNIST files.
images = load_images(train_image_file)
labels = load_labels(train_label_file)
# One-hot encoding of the labels (row i is the indicator vector of labels[i]).
onehot = to_onehot(labels)

In [None]:
# Flatten every training image into one feature row; keep labels as a 1-D
# target vector of the same length.
n = images.shape[0]
X = images.reshape(n, -1)
y = labels.reshape(n)

# Nu-SVC with the solver capped at 20 iterations via max_iter.
clf = svm.NuSVC(max_iter=20)
clf.fit(X, y)

In [None]:
# Load the held-out test set.
test_images = load_images(test_image_file)
test_labels = load_labels(test_label_file)

# Flatten each test image into one feature row. The row count is taken from
# the loaded array rather than hard-coded, so a test file with a different
# number of images is not silently reshaped to the wrong feature width.
pred_labels = clf.predict(test_images.reshape(len(test_images), -1))

# Accuracy is the fraction of predictions that equal the true labels.
acc = (pred_labels == test_labels).mean()
print('Accuracy: {:.4f}'.format(acc))