In [2]:
import struct
import os
import random
import math
import numpy as np
import matplotlib.pyplot as plt
import pickle

np.random.seed(3)

In [16]:
def decode_labels(file):
    with open(file, 'rb') as f:
        binary_data = f.read()
        _, num_items = struct.unpack_from('>II', binary_data, 0)
        labels       = struct.unpack_from('B'*num_items, binary_data, 8)
        return np.array(labels).reshape(-1, 1).astype(np.int)

def decode_images(file):
    with open(file, 'rb') as f:
        binary_data = f.read()
        _,num_images, rows, cols = struct.unpack_from('>IIII', binary_data, 0)
        images                   = struct.unpack_from('B'*(num_images*rows*cols), binary_data, 16)
        return np.array(images).reshape(-1, rows*cols)

filepath = [ "../stage_1/data/mnist/train-images-idx3-ubyte",
             "../stage_1/data/mnist/train-labels-idx1-ubyte",
             "../stage_1/data/mnist/t10k-images-idx3-ubyte",
             "../stage_1/data/mnist/t10k-labels-idx1-ubyte"]

train_images = decode_images(filepath[0])
train_labels = decode_labels(filepath[1])
test_images = decode_images(filepath[2])
test_labels = decode_labels(filepath[3])

In [17]:
class Dataset:
    def __getitem__(self, index):
        raise NotImplementedError()
    def __len__(self):
        raise NotImplementedError()

class MNISTDataset(Dataset):
    def __init__(self, image_file, label_file):
        self.num_classes = 10
        self.images = decode_images(image_file)
        self.labels = decode_labels(label_file)
        self.images = (self.images / 255.0 - 0.5).astype(np.float32)
        # pytorch里面， CELoss使用的不是one_hot
        self.labels_one_hot = np.eye(self.num_classes)[self.labels,:]
        
    def __getitem__(self, index):
        return self.images[index], self.labels[index], self.labels_one_hot[index]
    
    def __len__(self):
        return len(self.images)
    
    def __iter__(self):
        return DataLoader(self)

class DataLoader:
    def __init__(self, dataset, batch_size, shuffle=True):
        self.dataset = dataset
        self.shuffle = shuffle
        self.batch_size = batch_size
    
    def __iter__(self):
        return DataLoaderIterator(self)
    
    def __len__(self):
        '''
        返回一共可以传回多少轮（以每轮batch_size大小为依据）
        这里用drop last方案
        '''
        return len(self.dataset) // self.batch_size
    
class DataLoaderIterator:
    '''
    负责一轮数据的打乱，封装
    '''
    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.dataset = dataloader.dataset
        self.cursor = 0
        self.indexs = list(range(len(self.dataset)))
        if self.dataloader.shuffle:
            np.random.shuffle(self.indexs)
        
    def __next__(self):
        drop_last = len(self.dataset) % self.dataloader.batch_size # drop掉的不需要
        if self.cursor >= len(self.dataset) - drop_last:
            raise StopIteration()
        batch_data = []
        for i in range(self.dataloader.batch_size):
            index = self.indexs[self.cursor]
            item = self.dataloader.dataset[index]
            batch_data.append(item)
            self.cursor += 1
        output = list(zip(*batch_data))  # 此时变成了3组，一每组是(n,)
        for i in range(len(output)):
            output[i] = np.vstack(output[i]) # 此时把每组里的(n,)变成了(n,1)
        return output

class Parameter:
    def __init__(self, data):
        self.data = data
        self.grad = np.zeros_like(data)
        
    def zero_grad(self):
        self.grad[...] = 0

class Module:
    '''需求：
    backwark, forward
    training status
    '''
    def __init__(self):
        self.training = True
        
    def forward(self, *args):
        raise NotImplementedError
    
    def backward(self, g):
        raise NotImplementedError
        
    def __call__(self, *args):
        return self.forward(*args)
    
    def modules(self):
        ms = []
        for attr in self.__dict__:
            m = self.__dict__[attr]
            if isinstance(m, Module):
                ms.append(m)
        return ms
    
    def params(self):
        ps = []
        for attr in self.__dict__:
            p = self.__dict__[attr]
            if isinstance(p, Parameter):
                ps.append(p)
                
        for m in self.modules():
            ps.extend(m.params())
        
        return ps
    
    def train(self):
        self.training = True
        for m in self.modules():
            m.train()
        return self
    
    def eval(self):
        self.training = False
        for m in self.modules():
            m.eval()
        return self
    

class Linear(Module):
    def __init__(self, num_input, num_output):
        super().__init__()
        # parameters with kaiming init
        self.weight = Parameter(np.random.normal(0, 1/np.sqrt(num_input), size=(num_input, num_output)))
        self.bias = Parameter(np.zeros((1, num_output)))
    
    def forward(self, x):
        self.x = x
        return x @ self.weight.data + self.bias.data
    
    def backward(self, g):
        # update
        dw = self.x.T @ g
        db = np.sum(g, 0, keepdims=True)
        self.weight.grad += dw
        self.bias.grad += db
        # pass back
        dx = g @ self.weight.data.T
        return dx
        
        
class Sigmoid(Module):
    def sigmoid_fn(self,x):  # sigmoid 函数
        xtemp = x.copy()
        epx = 0.0001
        p = xtemp < 0
        p1 = xtemp >= 0
        xtemp[p] = np.exp(xtemp[p]) / (np.exp(xtemp[p]) + 1)
        xtemp[p1] = 1 / (1 + np.exp(-xtemp[p1]))
        return np.clip(xtemp, a_min=epx, a_max=1 - epx)
    
    def forward(self, x):
        self.y = self.sigmoid_fn(x)
        return self.y
    
    def backward(self, g):
        return g * self.y * (1-self.y)

class Sequencial(Module):
    def __init__(self, *items):
        super().__init__()
        self.items = items
        
    def modules(self):
        return self.items
    
    def forward(self, x):
        for m in self.items:
            x = m(x)
        return x
    
    def backward(self, g):
        for m in self.items[::-1]:
            g = m.backward(g)
        return g
    
class SoftmaxCrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, x, y):
        '''
        f(x) = e^x / sum(e^x)
        '''
        self.y = y
        ex = np.exp(x)
        sumex = np.sum(ex, 1, keepdims=True)
        self.m = x.shape[0]
        self.p = ex / sumex
        J = -np.sum(self.y * np.log(self.p)) / self.m
        return J
        
    def backward(self, g):
        return g * (self.p - self.y) / self.m
    
class Network(Module):
    def __init__(self, num_feature, num_hidden, num_classes):
        super().__init__()
        self.layers = Sequencial(
            Linear(num_feature, num_hidden),
            Sigmoid(),
            Linear(num_hidden, num_classes)
        )
        self.loss = SoftmaxCrossEntropyLoss()
        
    def inference(self, x):
        return self.layers(x)


    def forward(self, x, target):
        return self.loss(self.inference(x), target)

    def backward(self, g=1):
        g = self.loss.backward(g)
        return self.layers.backward(g)
        
class Optimizer:
    def __init__(self, params, lr):
        self.params = params
        self.lr = lr
        
    def zero_grad(self):
        for p in self.params:
            p.zero_grad()
            
    def set_lr(self, lr):
        self.lr = lr
        
    def step(self):
        raise NotImplementedError()

class SGD(Optimizer):
    def __init__(self, params, lr=1e-3):
        super().__init__(params, lr)
        
    def step(self):
        for p in self.params:
            p.data -= self.lr * p.grad

In [20]:
# dataset = MNISTDataset("../stage_1/data/mnist/t10k-images-idx3-ubyte",
#              "../stage_1/data/mnist/t10k-labels-idx1-ubyte")
# loader = DataLoader(dataset, 256)
# for i, (a,b,c) in enumerate(loader):
#     print(i, a.shape, b.shape, c.shape)

In [21]:
# image, label, one_hot = dataset[5]
# image.shape, label, one_hot

In [22]:
# plt.imshow(image.reshape(28,28))

In [23]:
batch_size = 256
num_hidden = 128
num_feature = 784
num_classes = 10
num_epoch = 10

train_data = MNISTDataset("../stage_1/data/mnist/train-images-idx3-ubyte",
             "../stage_1/data/mnist/train-labels-idx1-ubyte")
train_loader = DataLoader(train_data, batch_size, True)
test_data = MNISTDataset("../stage_1/data/mnist/t10k-images-idx3-ubyte",
             "../stage_1/data/mnist/t10k-labels-idx1-ubyte")
test_loader = DataLoader(test_data, 512, True)

In [28]:
network = Network(num_feature, num_hidden, num_classes)
optim = SGD(network.params(), 0.8)

for epoch in range(num_epoch):
    for rounds, (imgs, _, one_hots) in  enumerate(train_loader):
        niter = epoch * len(train_loader) + rounds  # 每训练完batch_size张图片算1轮，每一个epoch有len(train_load)轮
        loss = network(imgs, one_hots)
        
        optim.zero_grad()
        network.backward()
        optim.step()
        
        if niter % 100 == 0 and niter > 0:
            progress = epoch + rounds / len(train_loader)
            correct = 0
            for imgs, lbls, _ in test_loader:
                probility = network.inference(imgs)
                labels = np.argmax(probility, axis=1)
                correct += np.sum(labels == lbls[:,0]) 
            accuracy = correct / len(test_data)
            print(f'epoch: {progress:.2f}/{num_epoch}, loss: {loss:.6f} accuracy: {accuracy*100:.2f}%')

epoch: 0.43/10, loss: 0.538022 accuracy: 80.43%
epoch: 0.85/10, loss: 0.442779 accuracy: 86.44%
epoch: 1.28/10, loss: 0.336532 accuracy: 88.32%
epoch: 1.71/10, loss: 0.392755 accuracy: 88.50%
epoch: 2.14/10, loss: 0.221910 accuracy: 89.16%
epoch: 2.56/10, loss: 0.293324 accuracy: 90.05%
epoch: 2.99/10, loss: 0.278882 accuracy: 90.56%
epoch: 3.42/10, loss: 0.307788 accuracy: 89.66%
epoch: 3.85/10, loss: 0.206971 accuracy: 90.89%
epoch: 4.27/10, loss: 0.233365 accuracy: 90.69%
epoch: 4.70/10, loss: 0.171882 accuracy: 91.49%
epoch: 5.13/10, loss: 0.208088 accuracy: 91.91%
epoch: 5.56/10, loss: 0.162880 accuracy: 92.13%
epoch: 5.98/10, loss: 0.178720 accuracy: 92.32%
epoch: 6.41/10, loss: 0.216926 accuracy: 92.29%
epoch: 6.84/10, loss: 0.178971 accuracy: 92.43%
epoch: 7.26/10, loss: 0.194361 accuracy: 92.69%
epoch: 7.69/10, loss: 0.183625 accuracy: 92.83%
epoch: 8.12/10, loss: 0.123009 accuracy: 92.74%
epoch: 8.55/10, loss: 0.110284 accuracy: 92.96%
epoch: 8.97/10, loss: 0.117449 accuracy: