In [2]:
import numpy as np
import matplotlib.pyplot as plt

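TODO:
- ~~normalize data~~ — already done in `load_data` (each of the 3072 input columns is shifted to zero mean and scaled to unit standard deviation)
- initialize W differently (both `train` versions currently draw uniform, all-positive weights with `np.random.rand(...)` times a small constant)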

In [3]:
def unpickle(file):
    """Load and return the object stored in the pickle file ``file``.

    ``encoding='bytes'`` makes 8-bit strings from Python 2 pickles load as
    ``bytes``, which is why the CIFAR batch is indexed with keys such as
    ``b'labels'`` and ``b'data'``.

    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    import pickle

    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

def load_data(path="data/cifar-10-batches-py/data_batch_1", no_classes=10):
    """Load one pickled CIFAR-10 batch and standardise every input column.

    Parameters
    ----------
    path : str
        Pickled batch file (default: the first training batch).
    no_classes : int
        Width of the one-hot encoding (10 for CIFAR-10).

    Returns
    -------
    X : (N x 3072) float array
        ``data`` shifted to zero mean and scaled to unit std per column
        (N = 10 000 for a CIFAR-10 batch).
    labels : (N x 1) array of class indices
    one_hot : (N x no_classes) float array
    mean, std : (1 x 3072) arrays
        Column statistics used for ``X``; keep them to undo the
        normalisation (e.g. for plotting). A zero std is replaced by 1 so a
        constant column maps to 0 instead of NaN/inf.
    """
    batch = unpickle(path)
    labels = np.asarray(batch[b'labels'])
    data = batch[b'data']
    N = len(labels)

    one_hot = np.zeros((N, no_classes))
    one_hot[np.arange(N), labels] = 1

    # normalize each column independently
    mean = np.mean(data, axis=0, keepdims=True)
    std = np.std(data, axis=0, keepdims=True)
    # A constant column would otherwise divide by zero.
    std = np.where(std == 0, 1.0, std)

    X = (data - mean) / std
    return X, labels.reshape(-1, 1), one_hot, mean, std

def ReLU(x):
    """Element-wise rectified linear unit: keep entries > 0, replace the rest with 0."""
    is_active = x > 0
    return np.where(is_active, x, 0)

def dReLU(x):
    """Element-wise derivative of ReLU: 1 where x > 0, otherwise 0 (including x == 0)."""
    is_active = x > 0
    return np.where(is_active, 1, 0)

def softmax(x, axis=0):
    """Numerically stable softmax along ``axis`` (default 0, i.e. per column).

    The maximum is subtracted separately for every slice along ``axis``
    before exponentiating. Subtracting one global maximum is only safe for a
    single vector: in a batch of columns, a column whose values lie far
    below the global max underflows to all zeros and then divides 0/0 = NaN.
    """
    x = np.asarray(x)
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

def plot_image(x, mean, std):
    """Display one normalised example as a small 32x32 RGB image.

    Parameters
    ----------
    x : array with 3072 elements
        One example in the normalised space produced by ``load_data``; a
        flat vector, a (3072, 1) column or a (1, 3072) row are all accepted.
    mean, std : arrays broadcastable against a (1, 3072) row
        The statistics returned by ``load_data``, used to undo the
        normalisation.
    """
    # Work on a row: transposing a (1, 3072) input to (3072, 1) would
    # broadcast against a (1, 3072) std into a (3072, 3072) array.
    row = np.asarray(x).reshape(1, -1)
    pixels = row * std + mean
    # Round away float round-off (254.9999 -> 255, not 254) and clip to the
    # 0-255 range expected for an integer image.
    pixels = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
    img = pixels.reshape(3, 32, 32)
    plt.figure(figsize=(2, 2))
    # imshow wants (height, width, channels)
    plt.imshow(np.transpose(img, (1, 2, 0)))

In [None]:
def train(n_iters=100, lr=0.001, seed=40):
    """Sanity check: run gradient descent on a single training example.

    Fits the 3072-10-10 ReLU/softmax network to example 0 only and prints
    the cross-entropy loss after every forward pass; with a correct gradient
    the printed loss should go down.

    Parameters
    ----------
    n_iters : int
        Number of gradient steps (default 100).
    lr : float
        Learning rate (default 0.001).
    seed : int
        Seed for the global NumPy RNG used to initialise the weights.
    """
    np.random.seed(seed)

    data, _, one_hots, _, _ = load_data()

    one_hot = one_hots[0, :].reshape(1, 10)   # target as a (1, 10) row
    x = data[0, :].reshape(-1, 1)             # input as a (3072, 1) column

    assert x.shape == (3072, 1)

    W1 = np.random.rand(10, 3072) * 0.001
    b1 = np.zeros((10, 1))
    W2 = np.random.rand(10, 10) * 0.001
    b2 = np.zeros((10, 1))

    for _ in range(n_iters):
        # forward
        Z1 = np.dot(W1, x) + b1
        A1 = ReLU(Z1)

        Z2 = np.dot(W2, A1) + b2
        A2 = softmax(Z2)

        # Cross-entropy of this one example, -log(p_true); 1e-15 avoids log(0).
        L = -np.dot(one_hot, np.log(1e-15 + A2))
        print("loss: {}".format(L[0, 0]))

        # backward. For softmax + cross-entropy dL/dZ2 = y_hat - y; the
        # opposite sign (y - y_hat) combined with `W - lr * dW` below would
        # climb the loss instead of descending it.
        dL_dZ2 = A2 - one_hot.T

        dL_dW2 = np.dot(dL_dZ2, A1.T)   # (10, 10)
        dL_db2 = dL_dZ2

        dL_dA1 = np.dot(W2.T, dL_dZ2)
        dL_dZ1 = dL_dA1 * dReLU(Z1)

        dL_dW1 = np.dot(dL_dZ1, x.T)    # (10, 3072)
        dL_db1 = dL_dZ1

        # update
        W1 = W1 - lr * dL_dW1
        b1 = b1 - lr * dL_db1
        W2 = W2 - lr * dL_dW2
        b2 = b2 - lr * dL_db2


train()

In [33]:
# Train on every example in the batch; gradients are accumulated over the
# whole epoch and applied once per epoch.
# Seed the global NumPy RNG so the weight initialisation in train() repeats.
np.random.seed(40)


def check_if_correct(y, p):
    """Return 1 if ``y`` and ``p`` peak at the same (flattened) index, else 0.

    ``np.argmax`` without ``axis`` indexes the flattened array, so a (1, 10)
    one-hot row can be compared with a (10, 1) probability column.
    """
    same_class = np.argmax(y) == np.argmax(p)
    return int(same_class)

def train(epochs=101, lr=0.2, seed=40, print_every=5):
    """Full-batch gradient descent for the 3072-10-10 ReLU/softmax network.

    Every epoch runs the forward and backward pass for all examples at once
    (one example per column), averages the gradients over the batch and
    takes a single step. The mean cross-entropy loss and accuracy printed
    every ``print_every`` epochs are measured with the weights from *before*
    that epoch's step.

    Parameters
    ----------
    epochs : int
        Number of full-batch updates (default 101).
    lr : float
        Learning rate (default 0.2).
    seed : int
        Seed for the global NumPy RNG, set here so that calling train() again
        on its own reproduces the same initial weights.
    print_every : int
        Print progress every this many epochs (default 5).
    """
    np.random.seed(seed)

    data, _, one_hots, _, _ = load_data()
    data_size = data.shape[0]

    W1, W2 = np.random.rand(10, 3072) * 0.0001, np.random.rand(10, 10) * 0.0001
    b1, b2 = np.zeros((10, 1)), np.zeros((10, 1))

    # One example per column (rows of `data` are examples).
    X = data.T        # (3072, N)
    Y = one_hots.T    # (10, N)

    for epoch in range(epochs):
        # forward, all examples at once
        Z1 = np.dot(W1, X) + b1            # (10, N)
        A1 = ReLU(Z1)
        Z2 = np.dot(W2, A1) + b2           # (10, N)
        # Softmax per column, shifting each column by its own max for stability.
        e_Z2 = np.exp(Z2 - Z2.max(axis=0, keepdims=True))
        A2 = e_Z2 / e_Z2.sum(axis=0, keepdims=True)

        # Mean log loss and accuracy over the batch; 1e-15 avoids log(0).
        epoch_loss = -np.sum(Y * np.log(1e-15 + A2)) / data_size
        accuracy = np.mean(np.argmax(A2, axis=0) == np.argmax(Y, axis=0))

        # backward: dividing dZ2 by N averages every downstream gradient,
        # since they are all linear in dZ2.
        dZ2 = (A2 - Y) / data_size         # p - y
        dW2 = np.dot(dZ2, A1.T)
        db2 = dZ2.sum(axis=1, keepdims=True)

        dZ1 = np.dot(W2.T, dZ2) * dReLU(Z1)
        dW1 = np.dot(dZ1, X.T)
        db1 = dZ1.sum(axis=1, keepdims=True)

        # update
        W1 = W1 - lr * dW1
        b1 = b1 - lr * db1
        W2 = W2 - lr * dW2
        b2 = b2 - lr * db2

        if epoch % print_every == 0:
            print("epoch: {} \tloss: {:.5} \tacc: {:.3}".format(epoch, epoch_loss, accuracy))


train()

0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0


KeyboardInterrupt: 

Best results: *TODO — not recorded yet (the full-batch run above ended in a `KeyboardInterrupt`).*


In [23]:
def test():
    """Scratch check of how ``np.argmax`` treats a 2-D array without ``axis``.

    The (2, 3) array is round-tripped through a (6, 1) column into a (1, 6)
    row; ``np.argmax`` then returns the index into the flattened array, so
    the position of the value 6 prints as 5.
    """
    a = np.arange(1, 7).reshape(2, 3)
    a = a.reshape(-1, 1).reshape(1, -1)

    print(np.argmax(a))


test()

5


-34.538776394910684