In [1]:
import numpy as np
import matplotlib.pyplot as plt
import time

### Implementing classes and real network

In [2]:
def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

def load_data():
    """
    data : (10 000 x 3072)
    labels : (10 000 x 1)
    one_hot : (10 000 x 10)
    """
    
    file = unpickle("data/cifar-10-batches-py/data_batch_1")
    labels = file[b'labels']
    data = file[b'data']
    no_classes = 10
    N = len(labels)
    
    one_hot = np.zeros((N, no_classes))
    one_hot[np.arange(N), labels] = 1
    
    labels = np.array(labels).reshape(-1,1)
    
    # normalize
    mean = np.mean(data, axis=0, keepdims=True)
    std = np.std(data, axis=0, keepdims=True)
    
    X = (data - mean) / std
    return X, labels, one_hot, mean, std

def ReLU(x):
    return np.where(x > 0, x, 0)

def dReLU(x):
    return np.where(x > 0, 1, 0)

def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

def plot_image(x, mean, std):
    x = (x.T * std + mean).astype(int)
    img = x.reshape(3,32,32)
    plt.figure(figsize=(2,2))
    plt.imshow(np.transpose(img, (1,2,0)))

def check_if_correct(y, p):
    temp = np.argmax(y, axis=1) - np.argmax(p.T, axis=1)
    correct_ones = np.where(temp == 0, 1, 0)
    return np.sum(correct_ones)

class Plotter:
    def __init__(self, title):
        self.title = title
        self.x = []
        self.y = []
        
    def add(self, epoch, cost):
        self.y.append(cost)
        self.x.append(epoch)
    
    def plot(self):
        fig, ax = plt.subplots()
        ax.plot(self.x, self.y)

        ax.set(xlabel="epochs", ylabel="cost", title=self.title)
        ax.grid()
        # fig.savefig("{}.png".format(self.title))
        plt.show()


In [None]:
batch_size = 100
w_decay = 0.01
lr = 0.05

class Layer:
    def __init__(self, in_size, out_size):
        self.W = np.random.rand(out_size, in_size) * 0.0001
        self.b = np.zeros((out_size, 1))
        
        self.gamma = np.random.rand(out_size, 1)
        self.beta = np.zeros((out_size, 1))
        
    def forward(self, input):
        self.input = input
        S = np.dot(self.W, input) + self.b
        self.S = S
        self.mu = np.sum(S, axis=1).reshape(-1,1) / batch_size
        self.var = np.var(S, axis=1)
        
        self.S_hat = np.dot(np.diag( (self.var + 1e-15) ** (-0.5) ), (S - self.mu))
        self.S_t = self.gamma * self.S_hat + self.beta
        
        self.output = ReLU(self.S_t)
        return self.output
    
    def backprop(self, G):
        G = G * dReLU(self.output)
        self.dgamma = np.sum(G * self.S_hat, axis=1, keepdims=True) / batch_size
        self.dbeta = np.sum(G, axis=1, keepdims=True) / batch_size
        
        G = G * self.gamma
        
        sigma = self.var.reshape(-1,1) + 1e-15
        sigma1 = sigma ** -0.5
        sigma2 = sigma ** -1.5
        
        G1 = G * sigma1
        G2 = G * sigma2
        
        D = self.S - self.mu
        c = np.sum(G2 * D, axis=1, keepdims=True)
        part1 = (1 / batch_size) * (G1 @ np.ones((batch_size, 1))) @ np.ones((1, batch_size))
        part2 = (1 / batch_size) * D * (c @ np.ones((1, batch_size)))
        
        G = G1 - part1 - part2
        
        self.dW = np.dot(G, self.input.T) / batch_size
        self.dW += w_decay * 2 * self.W
        self.db = np.sum(G, axis=1).reshape(-1,1) / batch_size
        
        return np.dot(self.W.T, G)
    
    def update(self):
        self.W = self.W - lr * self.dW
        self.b = self.b - lr * self.b
        self.gamma = self.gamma - lr * self.dgamma
        self.beta = self.beta - lr * self.dbeta
        
class LastLayer:
    def __init__(self, in_size, no_classes):
        self.W = np.random.rand(no_classes, in_size) * 0.0001
        self.b = np.zeros((no_classes, 1))

    def forward(self, input):
        self.input = input
        S = np.dot(self.W, input) + self.b
        self.output = softmax(S)
        return self.output
    
    def backprop(self, G):
        self.dW = np.dot(G, self.input.T) / batch_size
        self.dW += 2 * w_decay * self.W
        self.db = np.sum(G, axis=1).reshape(-1,1) / batch_size
        
        return np.dot(self.W.T, G)

    def update(self):
        self.W = self.W - lr * self.dW
        self.b = self.b - lr * self.b


def train():
    plotter = Plotter("batch-norm-tester-".format(time.time()))
    data, _, one_hots, mean, std = load_data()
    data_size = data.shape[0]
    no_hidden_1 = 50
    no_hidden_2 = 30
    output_nodes = 10
    
    layer1 = Layer(3072, no_hidden_1)
    layer2 = Layer(no_hidden_1, no_hidden_2)
    layer3 = LastLayer(no_hidden_2, output_nodes)
    
    iterations = data_size // batch_size
    start_t = time.time()
    
    for epoch in range(200):
        avg_loss = 0
        accuracy = 0
        
        for idx in range(iterations):
            start = batch_size * idx
            end = batch_size * (idx + 1)
            
            # get X and Y
            X0 = data[start:end,:].T # 3072x100
            one_hot = one_hots[start:end,:] # 100x10
            
            X1 = layer1.forward(X0)
            X2 = layer2.forward(X1)
            output = layer3.forward(X2) # X3
            
            # Loss
            L2 = np.sum(layer1.W ** 2) + np.sum(layer2.W ** 2) + np.sum(layer3.W ** 2)
            loss = -np.sum(one_hot * np.log(1e-15 + output.T)) + w_decay * L2
            
            avg_loss += (loss / batch_size)
            accuracy += check_if_correct(one_hot, output)
            
            # Backward
            G = output - one_hot.T
            G = layer3.backprop(G)
            G = layer2.backprop(G)
            G = layer1.backprop(G)
            
            # Update
            layer1.update()
            layer2.update()
            layer3.update()
            
        avg_loss /= iterations
        accuracy /= iterations
        
        if (epoch % 10 == 0):
            plotter.add(epoch, avg_loss)
            print("epoch: {} \tloss: {:.3} \tacc: {:.3}".format(epoch, avg_loss, accuracy))

    print("{:.3} s".format(time.time() - start_t))
    plotter.plot()
        
train()


epoch: 0 	loss: 2.21 	acc: 10.4
epoch: 10 	loss: 1.61 	acc: 42.9
epoch: 20 	loss: 1.3 	acc: 55.0
